# moshi_general/mimi_tokenizer.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit 22d5f88).
from moshi import models
loaders = models.loaders
from huggingface_hub import hf_hub_download
import torch
from pydub import AudioSegment
import numpy as np
# Mimi tokenizer checkpoint filename and the Hub repo it is fetched from.
# These mirror moshi's loaders.MIMI_NAME / loaders.DEFAULT_REPO.
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'

# Prefer GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix: the original defined MIMI_NAME/DEFAULT_REPO above but then passed
# loaders.DEFAULT_REPO / loaders.MIMI_NAME to hf_hub_download, leaving the
# local constants dead. Use the local constants so they actually control
# which checkpoint is downloaded (the values are identical, so behavior is
# unchanged).
mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)
def encode_audio(mimi, wav, device):
    """Tokenize a waveform into Mimi codes, one frame at a time.

    Runs the codec in streaming mode and encodes `wav` frame by frame,
    collecting each frame's code tensor.

    Args:
        mimi: Loaded Mimi codec exposing `sample_rate`, `frame_rate`,
            `streaming(batch_size=...)` and `encode(...)`.
        wav: Waveform tensor — assumes shape (1, channels, num_samples)
            to match `streaming(batch_size=1)`; TODO confirm against callers.
        device: Device each frame is moved to before encoding.

    Returns:
        A list of per-frame code tensors, each with trailing dim 1.
    """
    samples_per_frame = int(mimi.sample_rate / mimi.frame_rate)
    collected = []
    with torch.no_grad(), mimi.streaming(batch_size=1):
        for start in range(0, wav.shape[-1], samples_per_frame):
            chunk = wav[..., start:start + samples_per_frame].to(device)
            codes = mimi.encode(chunk)
            # Streaming invariant: exactly one code timestep per audio frame.
            assert codes.shape[-1] == 1, codes.shape
            collected.append(codes)
    return collected
def load_audio(wav_path, mimi):
    """Load a WAV file as a float tensor ready for Mimi encoding.

    Reads the file with pydub, downmixes multi-channel audio to mono,
    normalizes integer samples to [-1, 1], resamples to the codec's
    sample rate when needed, and truncates to a whole number of codec
    frames so streaming encode never sees a partial final frame.

    Args:
        wav_path: Path to a .wav file.
        mimi: Mimi codec providing `sample_rate` and `frame_rate`.

    Returns:
        Tensor of shape (1, 1, num_samples), num_samples a multiple of
        the codec frame size.
    """
    audio = AudioSegment.from_wav(wav_path)
    samples = np.array(audio.get_array_of_samples()).astype(np.float32)
    # Fix: multi-channel pydub audio is interleaved; the original treated
    # the interleaved stream as mono. Downmix by averaging channels.
    if audio.channels > 1:
        samples = samples.reshape(-1, audio.channels).mean(axis=1)
    # Fix: normalize by the full-scale value for the actual sample width
    # (2**15 for 16-bit, 2**31 for 32-bit — unchanged for those widths —
    # and now also correct for 8-bit and 24-bit audio).
    samples /= float(2 ** (8 * audio.sample_width - 1))
    wav = torch.from_numpy(samples).float().unsqueeze(0).unsqueeze(0)
    if audio.frame_rate != mimi.sample_rate:
        # NOTE(review): linear interpolation is a crude resampler; a
        # polyphase resampler would be higher fidelity, but this keeps
        # the original behavior for the common resampling path.
        wav = torch.nn.functional.interpolate(
            wav,
            scale_factor=mimi.sample_rate / audio.frame_rate,
            mode='linear',
            align_corners=False,
        )
    # Trim to a whole number of codec frames.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
    return wav