import logging
from pathlib import Path

import numpy as np
import sentencepiece
import torch
import torchaudio

from moshi.models import loaders, LMGen

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceRecipe:
    """Handles model inference for the Any-to-Any model."""

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Initialize the model.

        Args:
            model_path (str): Path to model directory with pre-downloaded files
            device (str): Device to run on ('cuda' or 'cpu')
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)

        # Mimi runs at 24 kHz with 12.5 codec frames per second,
        # i.e. 24000 / 12.5 = 1920 samples (80 ms) per frame.
        self.sample_rate = 24000
        self.frame_rate = 12.5

        logger.info(f"Initializing models from {model_path}")
        self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
        self.mimi = self.mimi.to(self.device)
        self.lm_gen = self.lm_gen.to(self.device)
        logger.info("Model initialization complete")

    def _initialize_models(self):
        """Initialize all required model components."""
        try:
            # Mimi audio codec (encoder/decoder), restricted to 8 codebooks.
            mimi_path = self.model_path / loaders.MIMI_NAME
            if not mimi_path.exists():
                raise RuntimeError(f"Mimi model not found at {mimi_path}")
            logger.info(f"Loading Mimi model from {mimi_path}")
            mimi = loaders.get_mimi(str(mimi_path), device=self.device)
            mimi.set_num_codebooks(8)

            # SentencePiece text tokenizer.
            tokenizer_path = self.model_path / loaders.TEXT_TOKENIZER_NAME
            if not tokenizer_path.exists():
                raise RuntimeError(f"Text tokenizer not found at {tokenizer_path}")
            logger.info(f"Loading text tokenizer from {tokenizer_path}")
            text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))

            # Moshi language model, wrapped in a streaming generator.
            moshi_path = self.model_path / loaders.MOSHI_NAME
            if not moshi_path.exists():
                raise RuntimeError(f"Language model not found at {moshi_path}")
            logger.info(f"Loading language model from {moshi_path}")
            moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)

            return mimi, text_tokenizer, lm_gen

        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise
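
    # Note: loaders.MIMI_NAME, loaders.TEXT_TOKENIZER_NAME and loaders.MOSHI_NAME
    # are filename constants exported by the moshi package, so the model directory
    # must contain files with exactly those names; see the download sketch near
    # the bottom of this file.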

    def _load_audio(self, audio_array: np.ndarray, sample_rate: int):
        """Load and preprocess audio into a (batch, channels, time) tensor."""
        try:
            # Mimi expects shape (B, C, T); add batch and channel dims to the mono input.
            wav = torch.from_numpy(audio_array).float().unsqueeze(0).unsqueeze(0)

            if sample_rate != self.sample_rate:
                logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                ).to(self.device)
                wav = resampler(wav.to(self.device))
            else:
                wav = wav.to(self.device)

            # Trim to a whole number of codec frames.
            frame_size = int(self.sample_rate / self.frame_rate)
            orig_length = wav.shape[-1]
            wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
            if wav.shape[-1] != orig_length:
                logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")

            return wav

        except Exception as e:
            logger.error(f"Audio loading failed: {str(e)}")
            raise

    def _pad_codes(self, all_codes, time_seconds=30):
        """Pad the code sequence with encoded silence up to a minimum duration."""
        try:
            min_frames = int(time_seconds * self.frame_rate)
            frame_size = int(self.sample_rate / self.frame_rate)

            if len(all_codes) < min_frames:
                frames_to_add = min_frames - len(all_codes)
                logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                with torch.no_grad(), self.mimi.streaming(batch_size=1):
                    chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
                    for _ in range(frames_to_add):
                        additional_code = self.mimi.encode(chunk)
                        all_codes.append(additional_code)

            return all_codes

        except Exception as e:
            logger.error(f"Code padding failed: {str(e)}")
            raise
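
    # Padding encodes real silence through Mimi rather than appending zero-valued
    # tokens: Mimi is a learned codec, so silence does not correspond to the
    # all-zeros code. Encoding a zero waveform yields the proper silence tokens.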

    def _encode_audio(self, wav: torch.Tensor):
        """Convert audio to codes."""
        try:
            frame_size = int(self.sample_rate / self.frame_rate)
            all_codes = []

            with torch.no_grad(), self.mimi.streaming(batch_size=1):
                for offset in range(0, wav.shape[-1], frame_size):
                    frame = wav[:, :, offset: offset + frame_size]
                    codes = self.mimi.encode(frame.to(self.device))
                    assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
                    all_codes.append(codes)

            logger.info(f"Encoded {len(all_codes)} frames")
            return all_codes

        except Exception as e:
            logger.error(f"Audio encoding failed: {str(e)}")
            raise

    def _warmup(self):
        """Run a warmup pass."""
        try:
            frame_size = int(self.sample_rate / self.frame_rate)
            chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                codes = self.mimi.encode(chunk)
                tokens = self.lm_gen.step(codes[:, :, 0:1])
                if tokens is not None:
                    _ = self.mimi.decode(tokens[:, 1:])

            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            logger.info("Warmup pass completed")

        except Exception as e:
            logger.error(f"Warmup failed: {str(e)}")
            raise
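
    # The warmup pass pushes one frame of silence through the full
    # encode -> LM step -> decode path so that one-time costs (CUDA kernel
    # selection, memory allocation) are paid before real generation starts.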

    def _generate(self, all_codes):
        """Generate audio and text from codes."""
        try:
            out_wav_chunks = []
            text_output = []

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                for i, code in enumerate(all_codes):
                    assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
                    tokens_out = self.lm_gen.step(code.to(self.device))

                    if tokens_out is not None:
                        # tokens_out[:, 0] is the text token; tokens_out[:, 1:] are
                        # the 8 Mimi codebooks for one frame of output audio.
                        wav_chunk = self.mimi.decode(tokens_out[:, 1:])
                        out_wav_chunks.append(wav_chunk)

                        # Skip special tokens (0 and 3) before detokenizing.
                        text_token = tokens_out[0, 0, 0].item()
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            _text = _text.replace("▁", " ")
                            text_output.append(_text)

                    if (i + 1) % 100 == 0:
                        logger.info(f"Processed {i + 1}/{len(all_codes)} frames")

            wav = torch.cat(out_wav_chunks, dim=-1)
            text = ''.join(text_output)

            logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
            return wav, text

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise

    def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
        """Run inference on input audio.

        Args:
            audio_array (np.ndarray): Input audio as numpy array
            sample_rate (int): Sample rate of input audio

        Returns:
            dict: Contains the generated audio array and the generated text
        """
        try:
            logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate} Hz on device {self.device}")

            # Preprocess: reshape, resample, frame-align (moved to self.device inside).
            wav = self._load_audio(audio_array, sample_rate)

            # Encode to Mimi codes and pad to the minimum generation length.
            all_codes = self._encode_audio(wav)
            all_codes = self._pad_codes(all_codes)

            # Warm up the models before the generation loop.
            self._warmup()

            # Run the language model over the codes to produce audio and text.
            out_wav, text = self._generate(all_codes)

            output = out_wav.cpu().numpy().squeeze()

            logger.info("Inference completed successfully")
            return {
                "audio": output,
                "text": text
            }

        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise
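

# A minimal sketch of how the model directory could be populated ahead of time.
# It assumes the huggingface_hub package is installed and that loaders.DEFAULT_REPO
# points at a Hugging Face repo containing the three expected files; swap in your
# own repo id if you use a different checkpoint.
def download_model_files(target_dir: str) -> str:
    from huggingface_hub import hf_hub_download

    for name in (loaders.MIMI_NAME, loaders.TEXT_TOKENIZER_NAME, loaders.MOSHI_NAME):
        hf_hub_download(loaders.DEFAULT_REPO, name, local_dir=target_dir)
    return target_dir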


if __name__ == "__main__":
    import librosa

    model = InferenceRecipe("/path/to/models", device="cuda")

    # librosa.load with sr=None keeps the file's native sample rate;
    # inference() resamples to 24 kHz if needed.
    audio, sr = librosa.load("test.wav", sr=None)

    result = model.inference(audio, sr)
    print(f"Generated {len(result['audio'])} samples of audio")
    print(f"Generated text: {result['text']}")