File size: 1,513 Bytes
22d5f88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from moshi import models
loaders = models.loaders
from huggingface_hub import hf_hub_download
import torch
from pydub import AudioSegment
import numpy as np
# Mimi tokenizer checkpoint and the HF Hub repo that hosts it.
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'
# Prefer GPU when available; the mimi model runs on either device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use the local constants (they were previously dead code: the call read the
# identical loaders.DEFAULT_REPO / loaders.MIMI_NAME attributes instead).
mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)
def encode_audio(mimi, wav, device):
    """Encode a waveform into Mimi codebook tokens, one frame at a time.

    Args:
        mimi: streaming Mimi codec exposing sample_rate, frame_rate,
            streaming() and encode().
        wav: tensor of shape (batch, channels, samples); samples is assumed
            to be a multiple of the frame size (see load_audio).
        device: device string/object the frames are moved to before encoding.

    Returns:
        List of per-frame code tensors, each with a trailing dim of 1.
    """
    # Number of raw samples consumed per codec frame.
    step = int(mimi.sample_rate / mimi.frame_rate)
    codes_per_frame = []
    with torch.no_grad(), mimi.streaming(batch_size=1):
        start = 0
        total = wav.shape[-1]
        while start < total:
            chunk = wav[:, :, start:start + step]
            frame_codes = mimi.encode(chunk.to(device))
            # Streaming encode should emit exactly one code step per frame.
            assert frame_codes.shape[-1] == 1, frame_codes.shape
            codes_per_frame.append(frame_codes)
            start += step
    return codes_per_frame
def load_audio(wav_path, mimi):
    """Load a WAV file as a (1, 1, samples) float tensor ready for Mimi.

    Args:
        wav_path: path to a .wav file readable by pydub.
        mimi: Mimi codec providing the target sample_rate and frame_rate.

    Returns:
        Mono float32 tensor of shape (1, 1, samples), normalized to
        roughly [-1, 1), resampled to mimi.sample_rate, and truncated to a
        whole number of codec frames.
    """
    audio = AudioSegment.from_wav(wav_path)
    samples = np.array(audio.get_array_of_samples())
    # pydub interleaves channels in a flat array; de-interleave and downmix
    # to mono (the original treated stereo as mono at double the rate).
    if audio.channels > 1:
        samples = samples.reshape(-1, audio.channels).mean(axis=1)
    # Normalize by the actual sample width (1/2/3/4 bytes). The original
    # only handled 16-bit correctly and assumed 32-bit for everything else.
    full_scale = float(2 ** (8 * audio.sample_width - 1))
    samples = samples.astype(np.float32) / full_scale
    wav = torch.from_numpy(samples).float().unsqueeze(0).unsqueeze(0)
    if audio.frame_rate != mimi.sample_rate:
        # NOTE(review): linear interpolation is a crude resampler (no
        # anti-aliasing); kept for behavior parity with the original.
        wav = torch.nn.functional.interpolate(
            wav,
            scale_factor=mimi.sample_rate / audio.frame_rate,
            mode='linear',
            align_corners=False,
        )
    # Truncate to a whole number of frames so streaming encode sees only
    # full frames (encode_audio asserts one code step per frame).
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
    return wav
|