|
from moshi import models |
|
loaders = models.loaders |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
from pydub import AudioSegment |
|
import numpy as np |
|
|
|
# Checkpoint filename and Hugging Face repo for the Mimi audio tokenizer.
# These mirror the defaults shipped in `moshi.models.loaders`.
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'

DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'


# Prefer GPU when available; Mimi also runs on CPU, just slower.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fetch the Mimi weights from the Hub (cached locally after the first call)
# and build the codec.  Use the module-level constants above — not the
# loaders' copies — so overriding MIMI_NAME/DEFAULT_REPO here takes effect.
mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)

mimi = loaders.get_mimi(mimi_weight, device=device)
|
|
|
def encode_audio(mimi, wav, device):
    """Encode a waveform into Mimi codes, one streaming frame at a time.

    Args:
        mimi: Mimi codec exposing ``sample_rate``, ``frame_rate``,
            ``streaming`` (context manager) and ``encode``.
        wav: waveform tensor — assumed shape (batch, channels, samples),
            matching what ``load_audio`` produces; TODO confirm for other callers.
        device: device the codec lives on; each frame is moved there.

    Returns:
        List of per-frame code tensors, each with a trailing dim of 1.
    """
    samples_per_frame = int(mimi.sample_rate / mimi.frame_rate)
    frames = []
    with torch.no_grad(), mimi.streaming(batch_size=1):
        for chunk in torch.split(wav, samples_per_frame, dim=-1):
            encoded = mimi.encode(chunk.to(device))
            # Streaming encode must emit exactly one code step per frame.
            assert encoded.shape[-1] == 1, encoded.shape
            frames.append(encoded)
    return frames
|
|
|
|
|
def load_audio(wav_path, mimi):
    """Load a WAV file as a float tensor ready for ``encode_audio``.

    Args:
        wav_path: path to a ``.wav`` file readable by pydub.
        mimi: Mimi codec providing the target ``sample_rate`` and
            ``frame_rate``.

    Returns:
        Tensor of shape (1, 1, T) with samples normalized to [-1, 1],
        resampled to ``mimi.sample_rate`` and truncated to a whole number
        of codec frames.
    """
    audio = AudioSegment.from_wav(wav_path)

    # Downmix multi-channel audio first: get_array_of_samples() interleaves
    # channels, which would otherwise be treated as one garbled mono track.
    if audio.channels > 1:
        audio = audio.set_channels(1)

    samples = np.array(audio.get_array_of_samples()).astype(np.float32)
    # Normalize by full scale for the actual sample width (1/2/4 bytes) —
    # the previous 16-bit/32-bit-only branch mis-scaled 8-bit audio.
    samples /= float(2 ** (8 * audio.sample_width - 1))
    wav = torch.from_numpy(samples).float().unsqueeze(0).unsqueeze(0)

    if audio.frame_rate != mimi.sample_rate:
        # NOTE(review): linear interpolation is a crude resampler (no
        # anti-aliasing); consider torchaudio.functional.resample upstream.
        wav = torch.nn.functional.interpolate(
            wav,
            scale_factor=mimi.sample_rate / audio.frame_rate,
            mode='linear',
            align_corners=False,
        )

    # Trim the tail so the streaming encoder only ever sees full frames.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]

    return wav
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|