|
|
|
|
|
|
|
"""Retrieves the pretrained models for Moshi and Mimi.""" |
|
from pathlib import Path |
|
|
|
from safetensors.torch import load_model |
|
import torch |
|
|
|
from moshi.models.compression import MimiModel |
|
from moshi.models.lm import LMModel |
|
from moshi.modules import SEANetEncoder, SEANetDecoder, transformer |
|
from moshi.quantization import SplitResidualVectorQuantizer |
|
|
|
# Audio sample rate in Hz that Mimi is built for (passed to MimiModel in get_mimi).
SAMPLE_RATE = 24000

# Codec token frame rate in Hz (passed to MimiModel in get_mimi).
FRAME_RATE = 12.5



# Checkpoint / tokenizer file names; presumably the artifact names inside
# DEFAULT_REPO -- TODO confirm against the hub repo contents.
TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'

MOSHI_NAME = 'model.safetensors'

MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'

# Default model repository id (looks like a Hugging Face Hub repo id).
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'
|
|
|
|
|
# Constructor kwargs shared by SEANetEncoder and SEANetDecoder (see get_mimi).
_seanet_kwargs = {

    "channels": 1,  # mono audio

    "dimension": 512,  # latent dimension; reused by the quantizer/transformer configs below

    "causal": True,

    "n_filters": 64,

    "n_residual_layers": 1,

    "activation": "ELU",

    "compress": 2,

    "dilation_base": 2,

    "disable_norm_outer_blocks": 0,

    "kernel_size": 7,

    "residual_kernel_size": 3,

    "last_kernel_size": 3,




    # We train using weight_norm, but during inference norm is fused into the
    # weights -- NOTE(review): assumption inferred from "none" here; confirm.
    "norm": "none",

    "pad_mode": "constant",

    # Striding ratios; their product is the encoder hop length used for
    # encoder_frame_rate in get_mimi -- TODO confirm in SEANetEncoder.
    "ratios": [8, 6, 5, 4],

    "true_skip": True,

}
|
# Constructor kwargs for SplitResidualVectorQuantizer (see get_mimi).
_quantizer_kwargs = {

    "dimension": 256,  # internal quantizer dimension

    "n_q": 32,  # total number of codebooks available (get_mimi later restricts to 8)

    "bins": 2048,  # codebook size; also the audio token cardinality in _lm_kwargs

    "input_dimension": _seanet_kwargs["dimension"],

    "output_dimension": _seanet_kwargs["dimension"],

}
|
# Constructor kwargs for the ProjectedTransformer applied on each side of the
# quantizer (used for both encoder_transformer and decoder_transformer in get_mimi).
_transformer_kwargs = {

    "d_model": _seanet_kwargs["dimension"],

    "num_heads": 8,

    "num_layers": 8,

    "causal": True,

    "layer_scale": 0.01,

    "context": 250,  # attention context window, in frames -- TODO confirm unit

    "conv_layout": True,

    "max_period": 10000,

    "gating": "none",

    "norm": "layer_norm",

    "positional_embedding": "rope",

    "dim_feedforward": 2048,

    "input_dimension": _seanet_kwargs["dimension"],

    "output_dimensions": [_seanet_kwargs["dimension"]],

}
|
|
|
# Constructor kwargs for LMModel (see get_moshi_lm): a main temporal
# transformer plus a smaller depth transformer ("depformer") over codebooks.
_lm_kwargs = {

    "dim": 4096,  # main transformer width

    "text_card": 32000,  # text token cardinality

    "existing_text_padding_id": 3,

    "n_q": 16,  # audio codebooks consumed as input

    "dep_q": 8,  # audio codebooks generated by the depformer

    "card": _quantizer_kwargs["bins"],  # audio token cardinality (2048)

    "num_heads": 32,

    "num_layers": 32,

    "hidden_scale": 4.125,

    "causal": True,

    "layer_scale": None,

    "context": 3000,

    "max_period": 10000,

    "gating": "silu",

    "norm": "rms_norm_f32",

    "positional_embedding": "rope",

    # Depth transformer hyperparameters.
    "depformer_dim": 1024,

    "depformer_dim_feedforward": int(4.125 * 1024),  # same hidden_scale as the main transformer

    "depformer_num_heads": 16,

    "depformer_num_layers": 6,

    "depformer_causal": True,

    "depformer_layer_scale": None,

    "depformer_context": 8,

    "depformer_max_period": 10000,

    "depformer_gating": "silu",

    "depformer_pos_emb": "none",

    "depformer_weights_per_step": True,

    # Per-stream delays: 17 entries, presumably 1 text stream + 16 audio
    # codebook streams (n_q) -- TODO confirm against LMModel.
    "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],

}
|
|
|
|
|
def _is_safetensors(path: Path | str) -> bool: |
|
return Path(path).suffix in (".safetensors", ".sft", ".sfts") |
|
|
|
|
|
def get_mimi(filename: str | Path,
             device: torch.device | str = 'cpu') -> MimiModel:
    """Build a Mimi codec from the module-level hyperparameters and load
    pretrained weights.

    Args:
        filename: checkpoint path. Files with a safetensors extension are
            loaded via ``safetensors``; anything else goes through
            ``torch.load`` and is expected to hold the state dict under
            the ``"model"`` key.
        device: device the model is constructed on and moved to.

    Returns:
        The Mimi model in eval mode, restricted to 8 codebooks.
    """
    seanet_enc = SEANetEncoder(**_seanet_kwargs)
    seanet_dec = SEANetDecoder(**_seanet_kwargs)
    enc_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs)
    dec_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs)
    vq = SplitResidualVectorQuantizer(**_quantizer_kwargs)
    mimi = MimiModel(
        seanet_enc,
        seanet_dec,
        vq,
        channels=1,
        sample_rate=SAMPLE_RATE,
        frame_rate=FRAME_RATE,
        encoder_frame_rate=SAMPLE_RATE / seanet_enc.hop_length,
        causal=True,
        resample_method="conv",
        encoder_transformer=enc_transformer,
        decoder_transformer=dec_transformer,
    ).to(device=device)
    mimi.eval()
    if _is_safetensors(filename):
        load_model(mimi, filename)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects -- only feed
        # trusted checkpoint files through this branch.
        checkpoint = torch.load(filename, "cpu")
        mimi.load_state_dict(checkpoint["model"])
    # Expose only the first 8 of the quantizer's 32 codebooks.
    mimi.set_num_codebooks(8)
    return mimi
|
|
|
|
|
def get_moshi_lm(filename: str | Path,
                 device: torch.device | str = 'cpu') -> LMModel:
    """Build the Moshi language model and load pretrained weights.

    Args:
        filename: checkpoint path. Safetensors files are loaded directly;
            other files go through ``torch.load`` and are expected to hold
            the state dict under ``["fsdp_best_state"]["model"]``.
        device: device the model is constructed on and moved to.

    Returns:
        The LM in eval mode, in bfloat16.
    """
    dtype = torch.bfloat16
    lm = LMModel(
        device=device,
        dtype=dtype,
        **_lm_kwargs,
    ).to(device=device, dtype=dtype)
    lm.eval()
    if _is_safetensors(filename):
        load_model(lm, filename)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects -- only feed
        # trusted checkpoint files through this branch.
        checkpoint = torch.load(filename, "cpu")
        lm.load_state_dict(checkpoint["fsdp_best_state"]["model"])
    return lm
|
|