from moshi.models import loaders
from huggingface_hub import hf_hub_download
import torch
from pydub import AudioSegment
import numpy as np

MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'

device = "cuda" if torch.cuda.is_available() else "cpu"
mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)

def encode_audio(mimi, wav, device):
    # Mimi consumes fixed-size frames; each frame of samples yields one code step.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    all_codes = []
    with torch.no_grad(), mimi.streaming(batch_size=1):
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            codes = mimi.encode(frame.to(device))
            assert codes.shape[-1] == 1, codes.shape
            all_codes.append(codes)
    return all_codes

def load_audio(wav_path, mimi):
    audio = AudioSegment.from_wav(wav_path)
    # Downmix to mono: get_array_of_samples() interleaves channels otherwise.
    audio = audio.set_channels(1)
    samples = np.array(audio.get_array_of_samples())
    # Normalize integer PCM to [-1, 1] (16-bit or 32-bit samples).
    samples = samples.astype(np.float32) / (2**15 if audio.sample_width == 2 else 2**31)
    wav = torch.from_numpy(samples).float().unsqueeze(0).unsqueeze(0)
    # Resample to Mimi's expected sample rate if the file differs.
    if audio.frame_rate != mimi.sample_rate:
        wav = torch.nn.functional.interpolate(
            wav,
            scale_factor=mimi.sample_rate / audio.frame_rate,
            mode='linear',
            align_corners=False,
        )
    # Trim to a whole number of frames so every encode() call sees a full frame.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
    return wav
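
# A minimal usage sketch: the file path 'sample.wav' and the final stacking
# along the time axis are illustrative assumptions, not part of the original.
if __name__ == "__main__":
    wav = load_audio("sample.wav", mimi)
    codes = encode_audio(mimi, wav, device)
    # Each element has shape [batch, n_codebooks, 1]; concatenating along the
    # last dim gives the full token sequence for the clip.
    tokens = torch.cat(codes, dim=-1)
    print(tokens.shape)  # e.g. torch.Size([1, n_codebooks, n_frames])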