# moshi_general / inference.py
import logging
from pathlib import Path

import numpy as np
import sentencepiece
import torch
import torchaudio

from moshi.models import LMGen, loaders
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class InferenceRecipe:
"""Handles model inference for the Any-to-Any model."""
def __init__(self, model_path: str, device: str='cuda'):
"""Initialize the model.
Args:
model_path (str): Path to model directory with pre-downloaded files
device (str): Device to run on ('cuda' or 'cpu')
"""
self.device = torch.device(device)
self.model_path = Path(model_path)
# Set sample rate and frame rate
self.sample_rate = 24000 # Based on model config in loaders.py
self.frame_rate = 12.5 # Based on model config in loaders.py
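        # Derived frame size: sample_rate / frame_rate = 24000 / 12.5
        # = 1920 samples per Mimi frame, i.e. 80 ms of audio.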
# Initialize all model components
logger.info(f"Initializing models from {model_path}")
self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
self.mimi = self.mimi.to(self.device)
self.lm_gen = self.lm_gen.to(self.device)
logger.info("Model initialization complete")
def _initialize_models(self):
"""Initialize all required model components."""
print("Initializing models...")
try:
# Load MIMI model for encoding/decoding
mimi_path = self.model_path / loaders.MIMI_NAME
            if not mimi_path.exists():
                raise FileNotFoundError(f"MIMI model not found at {mimi_path}")
logger.info(f"Loading MIMI model from {mimi_path}")
mimi = loaders.get_mimi(str(mimi_path), device=self.device)
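            # Moshi consumes only the first 8 of Mimi's codebooks, so restrict
            # the codec accordingly (an assumption based on the moshi defaults).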
mimi.set_num_codebooks(8)
# Load text tokenizer
tokenizer_path = self.model_path / loaders.TEXT_TOKENIZER_NAME
            if not tokenizer_path.exists():
                raise FileNotFoundError(f"Text tokenizer not found at {tokenizer_path}")
logger.info(f"Loading text tokenizer from {tokenizer_path}")
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))
# Load language model
moshi_path = self.model_path / loaders.MOSHI_NAME
            if not moshi_path.exists():
                raise FileNotFoundError(f"Language model not found at {moshi_path}")
logger.info(f"Loading language model from {moshi_path}")
moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
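            # Sampling temperatures: `temp` for the audio codebook streams,
            # `temp_text` for the text stream.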
lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
return mimi, text_tokenizer, lm_gen
except Exception as e:
logger.error(f"Model initialization failed: {str(e)}")
raise
def _load_audio(self, audio_array: np.ndarray, sample_rate: int):
"""Load and preprocess audio."""
try:
            # Convert to a (batch=1, channels=1, time) tensor as Mimi expects;
            # a single unsqueeze would break the 3-dim slicing below
            wav = torch.from_numpy(audio_array).float().unsqueeze(0).unsqueeze(0)
# Resample if needed
if sample_rate != self.sample_rate:
logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
# Create resampler on same device as input will be
resampler = torchaudio.transforms.Resample(
orig_freq=sample_rate,
new_freq=self.sample_rate
).to(self.device)
# Move wav to device before resampling
wav = resampler(wav.to(self.device))
else:
# If no resampling needed, still ensure wav is on correct device
wav = wav.to(self.device)
# Ensure frame alignment
frame_size = int(self.sample_rate / self.frame_rate)
orig_length = wav.shape[-1]
wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
if wav.shape[-1] != orig_length:
logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
return wav
except Exception as e:
logger.error(f"Audio loading failed: {str(e)}")
raise
    def _pad_codes(self, all_codes, time_seconds=30):
        """Pad the code sequence with encoded silence up to `time_seconds`."""
try:
min_frames = int(time_seconds * self.frame_rate)
frame_size = int(self.sample_rate / self.frame_rate)
if len(all_codes) < min_frames:
frames_to_add = min_frames - len(all_codes)
logger.info(f"Padding {frames_to_add} frames to reach minimum length")
with torch.no_grad(), self.mimi.streaming(batch_size=1):
# Create tensor on the correct device
chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
for _ in range(frames_to_add):
additional_code = self.mimi.encode(chunk)
all_codes.append(additional_code)
return all_codes
except Exception as e:
logger.error(f"Code padding failed: {str(e)}")
raise
def _encode_audio(self, wav: torch.Tensor):
"""Convert audio to codes."""
try:
frame_size = int(self.sample_rate / self.frame_rate)
all_codes = []
with torch.no_grad(), self.mimi.streaming(batch_size=1):
for offset in range(0, wav.shape[-1], frame_size):
frame = wav[:, :, offset: offset + frame_size]
codes = self.mimi.encode(frame.to(self.device))
assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
all_codes.append(codes)
logger.info(f"Encoded {len(all_codes)} frames")
return all_codes
except Exception as e:
logger.error(f"Audio encoding failed: {str(e)}")
raise
def _warmup(self):
"""Run a warmup pass."""
try:
frame_size = int(self.sample_rate / self.frame_rate)
# Create tensor on the correct device from the start
chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
            with torch.no_grad(), self.lm_gen.streaming(batch_size=1), self.mimi.streaming(batch_size=1):
codes = self.mimi.encode(chunk) # chunk already on correct device
tokens = self.lm_gen.step(codes[:, :, 0:1])
if tokens is not None:
_ = self.mimi.decode(tokens[:, 1:])
if self.device.type == 'cuda':
torch.cuda.synchronize()
logger.info("Warmup pass completed")
except Exception as e:
logger.error(f"Warmup failed: {str(e)}")
raise
def _generate(self, all_codes):
"""Generate audio and text from codes."""
try:
out_wav_chunks = []
text_output = []
            with torch.no_grad(), self.lm_gen.streaming(batch_size=1), self.mimi.streaming(batch_size=1):
for i, code in enumerate(all_codes):
assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
tokens_out = self.lm_gen.step(code.to(self.device))
if tokens_out is not None:
                        # Generate audio: tokens_out[:, 0] is the text stream,
                        # the remaining 8 entries are the audio codebooks
                        wav_chunk = self.mimi.decode(tokens_out[:, 1:])
out_wav_chunks.append(wav_chunk)
                        # Decode the text token, skipping special ids
                        # (0 and 3 are padding/special tokens in the text stream)
                        text_token = tokens_out[0, 0, 0].item()
                        if text_token not in (0, 3):
_text = self.text_tokenizer.id_to_piece(text_token)
_text = _text.replace("▁", " ")
text_output.append(_text)
if (i + 1) % 100 == 0:
logger.info(f"Processed {i + 1}/{len(all_codes)} frames")
            if not out_wav_chunks:
                raise RuntimeError("Generation produced no audio output")
            wav = torch.cat(out_wav_chunks, dim=-1)
text = ''.join(text_output)
logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
return wav, text
except Exception as e:
logger.error(f"Generation failed: {str(e)}")
raise
def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
"""Run inference on input audio.
Args:
audio_array (np.ndarray): Input audio as numpy array
sample_rate (int): Sample rate of input audio
Returns:
dict: Contains generated audio array and optional transcribed text
"""
try:
logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate} Hz, self device: {self.device}")
            # Load and preprocess audio (already moved to self.device inside _load_audio)
            wav = self._load_audio(audio_array, sample_rate)
# Convert to codes
all_codes = self._encode_audio(wav)
all_codes = self._pad_codes(all_codes)
# Warmup pass
self._warmup()
# Generate output
out_wav, text = self._generate(all_codes)
# Convert output to numpy
output = out_wav.cpu().numpy().squeeze()
logger.info("Inference completed successfully")
return {
"audio": output,
"text": text
}
except Exception as e:
logger.error(f"Inference failed: {str(e)}")
raise
if __name__ == "__main__":
# Example usage
import librosa
# Initialize model
model = InferenceRecipe("/path/to/models", device="cuda")
# Load test audio
audio, sr = librosa.load("test.wav", sr=None)
# Run inference
result = model.inference(audio, sr)
print(f"Generated {len(result['audio'])} samples of audio")
print(f"Generated text: {result['text']}")