import torch |
import numpy as np |
import torchaudio |
import sentencepiece |
import logging |
from pathlib import Path |
from moshi.models import loaders, LMGen |
# Configure the root logger at import time so INFO-level records are emitted,
# then create the module-level logger that the rest of this file writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class InferenceRecipe:
    """Handles model inference for the Any-to-Any model.

    Wraps three components loaded from ``model_path``: the MIMI audio codec,
    a SentencePiece text tokenizer and the Moshi language model (driven
    through ``LMGen``). ``inference`` is the public entry point; the
    underscore-prefixed methods are its individual pipeline stages.
    """

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Initialize the model.

        Args:
            model_path (str): Path to model directory with pre-downloaded files
            device (str): Device to run on ('cuda' or 'cpu')

        Raises:
            RuntimeError: If one of the expected model files is missing.
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)
        self.sample_rate = 24000  # Hz; input audio is resampled to this rate
        self.frame_rate = 12.5  # codec frames per second
        logger.info(f"Initializing models from {model_path}")
        self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
        self.mimi = self.mimi.to(self.device)
        self.lm_gen = self.lm_gen.to(self.device)
        logger.info("Model initialization complete")

    @property
    def _frame_size(self) -> int:
        """Number of audio samples in one codec frame (sample_rate / frame_rate)."""
        return int(self.sample_rate / self.frame_rate)

    def _initialize_models(self):
        """Initialize all required model components.

        Returns:
            tuple: ``(mimi, text_tokenizer, lm_gen)``.

        Raises:
            RuntimeError: If the MIMI, tokenizer or language-model file does
                not exist under ``self.model_path``.
        """
        # Use the module logger (not print) so this message follows the same
        # logging configuration as every other message of the class.
        logger.info("Initializing models...")
        try:
            mimi_path = self.model_path / loaders.MIMI_NAME
            if not mimi_path.exists():
                raise RuntimeError(f"MIMI model not found at {mimi_path}")
            logger.info(f"Loading MIMI model from {mimi_path}")
            mimi = loaders.get_mimi(str(mimi_path), device=self.device)
            # _generate() asserts codes of shape (1, 8, 1), i.e. 8 codebooks.
            mimi.set_num_codebooks(8)

            tokenizer_path = self.model_path / loaders.TEXT_TOKENIZER_NAME
            if not tokenizer_path.exists():
                raise RuntimeError(f"Text tokenizer not found at {tokenizer_path}")
            logger.info(f"Loading text tokenizer from {tokenizer_path}")
            text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))

            moshi_path = self.model_path / loaders.MOSHI_NAME
            if not moshi_path.exists():
                raise RuntimeError(f"Language model not found at {moshi_path}")
            logger.info(f"Loading language model from {moshi_path}")
            moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)

            return mimi, text_tokenizer, lm_gen
        except Exception as e:
            logger.error("Model initialization failed: %s", e)
            raise

    def _load_audio(self, audio_array: np.ndarray, sample_rate: int):
        """Load and preprocess audio.

        Converts the array to a float tensor with batch and channel
        dimensions, resamples it to ``self.sample_rate`` if necessary, moves
        it to ``self.device`` and trims trailing samples so the length is a
        whole number of codec frames.

        Args:
            audio_array (np.ndarray): Either 1-D ``(samples,)`` mono audio or
                2-D audio whose last axis is time.
            sample_rate (int): Sample rate of ``audio_array`` in Hz.

        Returns:
            torch.Tensor: Tensor of shape ``(1, channels, samples)``; a 1-D
            input yields ``(1, 1, samples)``.

        Raises:
            ValueError: If ``audio_array`` is neither 1-D nor 2-D.
        """
        try:
            wav = torch.from_numpy(audio_array).float()
            if wav.dim() == 1:
                # Mono input has no channel axis; add one. Without it the
                # three-index slice below fails on a 2-D tensor.
                wav = wav.unsqueeze(0)
            elif wav.dim() != 2:
                raise ValueError(
                    f"Expected 1-D or 2-D audio array, got {wav.dim()} dimensions"
                )
            wav = wav.unsqueeze(0)  # batch dimension

            if sample_rate != self.sample_rate:
                logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                ).to(self.device)
                wav = resampler(wav.to(self.device))
            else:
                wav = wav.to(self.device)

            # Drop the incomplete trailing frame so every slice taken by
            # _encode_audio() holds exactly one full frame.
            frame_size = self._frame_size
            orig_length = wav.shape[-1]
            wav = wav[:, :, :(orig_length // frame_size) * frame_size]
            if wav.shape[-1] != orig_length:
                logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
            return wav
        except Exception as e:
            logger.error("Audio loading failed: %s", e)
            raise

    def _pad_codes(self, all_codes, time_seconds=30):
        """Pad the code list with encoded silence up to a minimum duration.

        Args:
            all_codes (list): Per-frame codes as produced by ``_encode_audio``.
                The list is extended in place.
            time_seconds (int | float): Minimum duration; the minimum frame
                count is ``int(time_seconds * self.frame_rate)``.

        Returns:
            list: The same ``all_codes`` list, padded if it was too short.
        """
        try:
            min_frames = int(time_seconds * self.frame_rate)
            frames_to_add = min_frames - len(all_codes)
            if frames_to_add > 0:
                logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                # One frame of silence, encoded repeatedly in a fresh
                # streaming session of the codec.
                silence = torch.zeros(
                    1, 1, self._frame_size, dtype=torch.float32, device=self.device
                )
                with torch.no_grad(), self.mimi.streaming(batch_size=1):
                    for _ in range(frames_to_add):
                        all_codes.append(self.mimi.encode(silence))
            return all_codes
        except Exception as e:
            logger.error("Code padding failed: %s", e)
            raise

    def _encode_audio(self, wav: torch.Tensor):
        """Convert audio to codes.

        Args:
            wav (torch.Tensor): Audio tensor indexed as
                ``(batch, channel, samples)``, as returned by ``_load_audio``.

        Returns:
            list: One code tensor per frame; each has a last dimension of 1.
        """
        try:
            frame_size = self._frame_size
            all_codes = []
            with torch.no_grad(), self.mimi.streaming(batch_size=1):
                for offset in range(0, wav.shape[-1], frame_size):
                    frame = wav[:, :, offset: offset + frame_size]
                    codes = self.mimi.encode(frame.to(self.device))
                    assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
                    all_codes.append(codes)
            logger.info(f"Encoded {len(all_codes)} frames")
            return all_codes
        except Exception as e:
            logger.error("Audio encoding failed: %s", e)
            raise

    def _warmup(self):
        """Run a warmup pass.

        Pushes a single frame of silence through encode, one language-model
        step and (if the step produced tokens) decode, then synchronizes
        CUDA when running on a GPU.
        """
        try:
            chunk = torch.zeros(
                1, 1, self._frame_size, dtype=torch.float32, device=self.device
            )
            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                codes = self.mimi.encode(chunk)
                tokens = self.lm_gen.step(codes[:, :, 0:1])
                if tokens is not None:
                    _ = self.mimi.decode(tokens[:, 1:])
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            logger.info("Warmup pass completed")
        except Exception as e:
            logger.error("Warmup failed: %s", e)
            raise

    def _generate(self, all_codes):
        """Generate audio and text from codes.

        Args:
            all_codes (list): Per-frame code tensors, each of shape (1, 8, 1).

        Returns:
            tuple: ``(wav, text)`` where ``wav`` is the decoded audio chunks
            concatenated along the last axis and ``text`` is the joined
            decoded text pieces.

        Raises:
            RuntimeError: If the language model never produced output tokens,
                so there is no audio to concatenate.
        """
        try:
            out_wav_chunks = []
            text_output = []
            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                for i, code in enumerate(all_codes):
                    assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
                    tokens_out = self.lm_gen.step(code.to(self.device))
                    if tokens_out is not None:
                        # Stream 0 carries the text token; the remaining
                        # streams are decoded to audio by the codec.
                        wav_chunk = self.mimi.decode(tokens_out[:, 1:])
                        out_wav_chunks.append(wav_chunk)
                        text_token = tokens_out[0, 0, 0].item()
                        # Ids 0 and 3 are skipped (presumably padding /
                        # special tokens - confirm against the tokenizer).
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            _text = _text.replace("▁", " ")
                            text_output.append(_text)
                    if (i + 1) % 100 == 0:
                        logger.info(f"Processed {i + 1}/{len(all_codes)} frames")
            if not out_wav_chunks:
                # torch.cat() on an empty list raises an opaque RuntimeError;
                # raise the same exception type with an actionable message.
                raise RuntimeError(
                    "No audio was generated: the language model produced no output tokens"
                )
            wav = torch.cat(out_wav_chunks, dim=-1)
            text = ''.join(text_output)
            logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
            return wav, text
        except Exception as e:
            logger.error("Generation failed: %s", e)
            raise

    def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
        """Run inference on input audio.

        Args:
            audio_array (np.ndarray): Input audio as numpy array
            sample_rate (int): Sample rate of input audio

        Returns:
            dict: Contains generated audio array and optional transcribed text
        """
        try:
            logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate} Hz, self device: {self.device}")
            # _load_audio() already returns the tensor on self.device.
            wav = self._load_audio(audio_array, sample_rate)
            all_codes = self._encode_audio(wav)
            all_codes = self._pad_codes(all_codes)
            self._warmup()
            out_wav, text = self._generate(all_codes)
            output = out_wav.cpu().numpy().squeeze()
            logger.info("Inference completed successfully")
            return {
                "audio": output,
                "text": text
            }
        except Exception as e:
            logger.error("Inference failed: %s", e)
            raise
if __name__ == "__main__":
    # Demo entry point: load the models, run one file through the pipeline
    # and report what was generated. librosa is imported here so it is only
    # required when the module is executed as a script.
    import librosa

    recipe = InferenceRecipe("/path/to/models", device="cuda")
    # sr=None asks librosa for the file's native sample rate; inference()
    # receives that rate alongside the samples.
    waveform, native_sr = librosa.load("test.wav", sr=None)
    result = recipe.inference(waveform, native_sr)
    for report_line in (
        f"Generated {len(result['audio'])} samples of audio",
        f"Generated text: {result['text']}",
    ):
        print(report_line)