import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
from argparse import RawTextHelpFormatter
from functools import partial
from multiprocessing.pool import ThreadPool

import numpy as np
import torch
from tqdm import tqdm

from speaker.models.lstm import LSTMSpeakerEncoder
from speaker.config import SpeakerEncoderConfig
from speaker.utils.audio import AudioProcessor
from speaker.infer import read_json
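
# This script walks a dataset of .wav files, pushes each one through a
# pretrained LSTM speaker encoder, and writes one fixed-size speaker
# embedding (.spk.npy) per input file, mirroring the input folder layout.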


def get_spk_wavs(dataset_path, output_path):
    """Collect all wav paths under dataset_path and mirror its per-speaker
    subfolders in output_path so embeddings can be saved alongside."""
    wav_files = []
    os.makedirs(f"./{output_path}", exist_ok=True)
    for spks in os.listdir(dataset_path):
        if os.path.isdir(f"./{dataset_path}/{spks}"):
            os.makedirs(f"./{output_path}/{spks}", exist_ok=True)
            for file in os.listdir(f"./{dataset_path}/{spks}"):
                if file.endswith(".wav"):
                    wav_files.append(f"./{dataset_path}/{spks}/{file}")
        elif spks.endswith(".wav"):
            wav_files.append(f"./{dataset_path}/{spks}")
    return wav_files
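
# Expected layout (illustrative): per-speaker subfolders of .wav files, with
# any loose .wav files directly under dataset_path also picked up, e.g.
#   dataset_path/
#     speaker_a/0001.wav
#     speaker_b/0002.wav
#     loose.wav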


def process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder):
    """Compute a speaker embedding for one wav file and save it next to its
    mirrored path under output_path."""
    waveform = speaker_encoder_ap.load_wav(wav_file, sr=speaker_encoder_ap.sample_rate)
    spec = speaker_encoder_ap.melspectrogram(waveform)
    spec = torch.from_numpy(spec.T)
    if args.use_cuda:
        spec = spec.cuda()
    spec = spec.unsqueeze(0)
    embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
    embed = embed.squeeze()
    embed_path = wav_file.replace(dataset_path, output_path)
    embed_path = embed_path.replace(".wav", ".spk")
    # np.save appends ".npy", so the file on disk ends up as "<name>.spk.npy"
    np.save(embed_path, embed, allow_pickle=False)
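
# Illustrative sanity check (hypothetical path): each saved file holds a single
# 1-D embedding vector and can be reloaded with
#   embed = np.load("output_path/speaker_a/0001.spk.npy")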


def extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency):
    bound_process_wav = partial(process_wav, dataset_path=dataset_path, output_path=output_path, args=args, speaker_encoder_ap=speaker_encoder_ap, speaker_encoder=speaker_encoder)
    with ThreadPool(concurrency) as pool:
        list(tqdm(pool.imap(bound_process_wav, wav_files), total=len(wav_files)))
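
# Design note: a ThreadPool (not a process pool) lets every worker share the
# single, possibly GPU-resident speaker_encoder; numpy and torch release the
# GIL during heavy ops, so threads still overlap audio I/O with inference.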


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each wav file in a dataset.""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("dataset_path", type=str, help="Path to dataset waves.")
    parser.add_argument(
        "output_path", type=str, help="Path for the output .spk.npy embedding files."
    )
    # argparse's type=bool treats every non-empty string (even "False") as
    # True, so use an explicit boolean action instead (Python 3.9+).
    parser.add_argument(
        "--use_cuda", action=argparse.BooleanOptionalAction, default=True,
        help="Flag to use CUDA.",
    )
    parser.add_argument(
        "-t", "--thread_count", dest="thread_count", type=int, default=1,
        help="Thread count to process with; set 0 to use all CPU cores.",
    )
    args = parser.parse_args()

    dataset_path = args.dataset_path
    output_path = args.output_path
    thread_count = args.thread_count

    # pretrained speaker encoder checkpoint and its config
    args.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar")
    args.config_path = os.path.join("speaker_pretrain", "config.json")

    # config
    config_dict = read_json(args.config_path)
    config = SpeakerEncoderConfig(config_dict)
    config.from_dict(config_dict)

    # model
    speaker_encoder = LSTMSpeakerEncoder(
        config.model_params["input_dim"],
        config.model_params["proj_dim"],
        config.model_params["lstm_dim"],
        config.model_params["num_lstm_layers"],
    )
    speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)

    # preprocess
    speaker_encoder_ap = AudioProcessor(**config.audio)
    # normalize the input audio level and trim silences
    speaker_encoder_ap.do_sound_norm = True
    speaker_encoder_ap.do_trim_silence = True

    wav_files = get_spk_wavs(dataset_path, output_path)

    if thread_count == 0:
        process_num = os.cpu_count()
    else:
        process_num = thread_count

    extract_speaker_embeddings(
        wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, process_num
    )
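
# Example invocation (hypothetical script and data paths; assumes the
# pretrained encoder lives at ./speaker_pretrain/best_model.pth.tar with its
# config.json alongside, as hard-coded above):
#   python preprocess_speaker.py ./data/waves ./data/speaker -t 4
# Every <name>.wav under ./data/waves then yields a mirrored
# ./data/speaker/.../<name>.spk.npy embedding file.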