import logging
from pathlib import Path

import numpy as np
import sentencepiece
import torch
import torchaudio
from moshi.models import loaders, LMGen

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceRecipe:
    """Handles model inference for the Any-to-Any model."""

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Initialize the model.

        Args:
            model_path (str): Path to model directory with pre-downloaded files
            device (str): Device to run on ('cuda' or 'cpu')
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)
        # Set sample rate and frame rate
        self.sample_rate = 24000  # Based on model config in loaders.py
        self.frame_rate = 12.5  # Based on model config in loaders.py
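        # At 24 kHz and 12.5 frames/s, each codec frame covers
        # 24000 / 12.5 = 1920 samples; this ratio is reused throughout.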
        # Initialize all model components
        logger.info(f"Initializing models from {model_path}")
        self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
        self.mimi = self.mimi.to(self.device)
        self.lm_gen = self.lm_gen.to(self.device)
        logger.info("Model initialization complete")

    def _initialize_models(self):
        """Initialize all required model components."""
        logger.info("Initializing models...")
        try:
            # Load MIMI model for encoding/decoding
            mimi_path = self.model_path / loaders.MIMI_NAME
            if not mimi_path.exists():
                raise RuntimeError(f"MIMI model not found at {mimi_path}")
            logger.info(f"Loading MIMI model from {mimi_path}")
            mimi = loaders.get_mimi(str(mimi_path), device=self.device)
            mimi.set_num_codebooks(8)
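            # 8 codebooks matches what the language model consumes per frame
            # (Mimi itself can emit more); see the (1, 8, 1) shape check in
            # _generate below.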
            # Load text tokenizer
            tokenizer_path = self.model_path / loaders.TEXT_TOKENIZER_NAME
            if not tokenizer_path.exists():
                raise RuntimeError(f"Text tokenizer not found at {tokenizer_path}")
            logger.info(f"Loading text tokenizer from {tokenizer_path}")
            text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))
            # Load language model
            moshi_path = self.model_path / loaders.MOSHI_NAME
            if not moshi_path.exists():
                raise RuntimeError(f"Language model not found at {moshi_path}")
            logger.info(f"Loading language model from {moshi_path}")
            moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
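            # temp is the sampling temperature for the audio tokens and
            # temp_text for the text tokens; lower values sample more greedily.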
            return mimi, text_tokenizer, lm_gen
        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise

    def _load_audio(self, audio_array: np.ndarray, sample_rate: int):
        """Load and preprocess audio into a (batch, channels, samples) tensor."""
        try:
            # Convert to a tensor of shape (1, 1, samples); the slicing below
            # assumes three dimensions, so a mono array needs two unsqueezes
            wav = torch.from_numpy(audio_array).float().unsqueeze(0).unsqueeze(0)
            # Resample if needed
            if sample_rate != self.sample_rate:
                logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
                # Create the resampler on the same device as the input
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                ).to(self.device)
                # Move wav to the device before resampling
                wav = resampler(wav.to(self.device))
            else:
                # If no resampling is needed, still ensure wav is on the correct device
                wav = wav.to(self.device)
            # Ensure frame alignment by trimming to a whole number of frames
            frame_size = int(self.sample_rate / self.frame_rate)
            orig_length = wav.shape[-1]
            wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
            if wav.shape[-1] != orig_length:
                logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
            return wav
        except Exception as e:
            logger.error(f"Audio loading failed: {str(e)}")
            raise

    def _pad_codes(self, all_codes, time_seconds=30):
        """Pad the code sequence with encoded silence up to a minimum duration."""
        try:
            min_frames = int(time_seconds * self.frame_rate)
            frame_size = int(self.sample_rate / self.frame_rate)
            if len(all_codes) < min_frames:
                frames_to_add = min_frames - len(all_codes)
                logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                with torch.no_grad(), self.mimi.streaming(batch_size=1):
                    # Encode silent chunks, created on the correct device
                    chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
                    for _ in range(frames_to_add):
                        additional_code = self.mimi.encode(chunk)
                        all_codes.append(additional_code)
            return all_codes
        except Exception as e:
            logger.error(f"Code padding failed: {str(e)}")
            raise

    def _encode_audio(self, wav: torch.Tensor):
        """Convert audio to codes."""
        try:
            frame_size = int(self.sample_rate / self.frame_rate)
            all_codes = []
            with torch.no_grad(), self.mimi.streaming(batch_size=1):
                for offset in range(0, wav.shape[-1], frame_size):
                    frame = wav[:, :, offset: offset + frame_size]
                    codes = self.mimi.encode(frame.to(self.device))
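                    # In streaming mode, each frame-size chunk of samples
                    # yields exactly one timestep of codes (checked below)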
                    assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
                    all_codes.append(codes)
            logger.info(f"Encoded {len(all_codes)} frames")
            return all_codes
        except Exception as e:
            logger.error(f"Audio encoding failed: {str(e)}")
            raise

    def _warmup(self):
        """Run a warmup pass to prime the streaming state and CUDA kernels."""
        try:
            frame_size = int(self.sample_rate / self.frame_rate)
            # Create the tensor on the correct device from the start
            chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                codes = self.mimi.encode(chunk)  # chunk is already on the correct device
                tokens = self.lm_gen.step(codes[:, :, 0:1])
                if tokens is not None:
                    _ = self.mimi.decode(tokens[:, 1:])
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            logger.info("Warmup pass completed")
        except Exception as e:
            logger.error(f"Warmup failed: {str(e)}")
            raise

    def _generate(self, all_codes):
        """Generate audio and text from codes."""
        try:
            out_wav_chunks = []
            text_output = []
            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                for i, code in enumerate(all_codes):
                    assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
                    tokens_out = self.lm_gen.step(code.to(self.device))
                    if tokens_out is not None:
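                        # tokens_out stacks the text stream on top of the
                        # audio codebooks along dim 1: index 0 is the text
                        # token, indices 1: are the audio codes Mimi decodes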
                        # Generate audio
                        wav_chunk = self.mimi.decode(tokens_out[:, 1:])
                        out_wav_chunks.append(wav_chunk)
                        # Generate text if available (0 and 3 are special
                        # padding tokens in the text stream, so skip them)
                        text_token = tokens_out[0, 0, 0].item()
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            _text = _text.replace("▁", " ")
                            text_output.append(_text)
                    if (i + 1) % 100 == 0:
                        logger.info(f"Processed {i + 1}/{len(all_codes)} frames")
            wav = torch.cat(out_wav_chunks, dim=-1)
            text = ''.join(text_output)
            logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
            return wav, text
        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise

    def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
        """Run inference on input audio.

        Args:
            audio_array (np.ndarray): Input audio as numpy array
            sample_rate (int): Sample rate of input audio

        Returns:
            dict: Contains generated audio array and optional transcribed text
        """
        try:
            logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate} Hz, device: {self.device}")
            # Load and preprocess audio
            wav = self._load_audio(audio_array, sample_rate)
            wav = wav.to(self.device)
            # Convert to codes
            all_codes = self._encode_audio(wav)
            all_codes = self._pad_codes(all_codes)
            # Warmup pass
            self._warmup()
            # Generate output
            out_wav, text = self._generate(all_codes)
            # Convert output to numpy
            output = out_wav.cpu().numpy().squeeze()
            logger.info("Inference completed successfully")
            return {
                "audio": output,
                "text": text
            }
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise


if __name__ == "__main__":
    # Example usage
    import librosa

    # Initialize model
    model = InferenceRecipe("/path/to/models", device="cuda")
    # Load test audio
    audio, sr = librosa.load("test.wav", sr=None)
    # Run inference
    result = model.inference(audio, sr)
    print(f"Generated {len(result['audio'])} samples of audio")
    print(f"Generated text: {result['text']}")