# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example Run Command: python ssl_tts_vc.py --ssl_model_ckpt_path <path> --hifi_ckpt_path <path> \
# --fastpitch_ckpt_path <path> --source_audio_path <path> --target_audio_path <path> \
# --out_path <path>

import argparse
import os

import librosa
import soundfile
import torch

from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.tts.models import fastpitch_ssl, hifigan, ssl_tts
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir


def load_wav(wav_path, wav_featurizer, pad_multiple=1024):
    """
    Loads a waveform and zero-pads it up to a multiple of `pad_multiple` samples.

    Args:
        wav_path: path to the audio file
        wav_featurizer: object whose `process(path)` returns a 1-D waveform tensor
        pad_multiple: the waveform is zero-padded up to the next multiple of this many samples
    Returns:
        The padded waveform with its last sample dropped.
    """
    wav = wav_featurizer.process(wav_path)
    remainder = wav.shape[0] % pad_multiple
    if remainder != 0:
        padding = torch.zeros(pad_multiple - remainder, dtype=torch.float)
        wav = torch.cat([wav, padding])
    # NOTE(review): the final sample is always dropped, so the returned length is one short of a
    # multiple of pad_multiple when padding was applied -- presumably to match the frame count
    # expected downstream; confirm before changing.
    wav = wav[:-1]

    return wav


def get_pitch_contour(wav, pitch_mean=None, pitch_std=None, compute_mean_std=False, sample_rate=22050):
    """
    Estimates the pitch (f0) contour of a waveform with librosa's pYIN and optionally normalizes it.

    Args:
        wav: 1-D waveform tensor
        pitch_mean: mean used for normalization (ignored if compute_mean_std is True)
        pitch_std: standard deviation used for normalization (ignored if compute_mean_std is True)
        compute_mean_std: If True, the mean and std are computed from this contour
            (over all frames, including the zero-filled unvoiced ones).
        sample_rate: sample rate of `wav`
    Returns:
        Float32 tensor with one pitch value per frame. Unvoiced frames are 0. The contour is
        normalized only if both a mean and a std are available. A std of zero is not guarded against.
    """
    # pYIN runs on a CPU numpy array; detach/cpu makes this work for any input tensor.
    f0, _, _ = librosa.pyin(
        wav.detach().cpu().numpy(),
        fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'),
        frame_length=1024,
        hop_length=256,
        sr=sample_rate,
        center=True,
        fill_na=0.0,
    )
    pitch_contour = torch.tensor(f0, dtype=torch.float32)

    if compute_mean_std:
        pitch_mean = pitch_contour.mean().item()
        pitch_std = pitch_contour.std().item()

    if (pitch_mean is not None) and (pitch_std is not None):
        # Unvoiced frames were filled with 0.0 above. Record them before shifting so that they
        # stay at exactly 0 after normalization instead of becoming -mean / std.
        unvoiced = pitch_contour == 0.0
        pitch_contour = pitch_contour - pitch_mean
        pitch_contour[unvoiced] = 0.0
        pitch_contour = pitch_contour / pitch_std

    return pitch_contour


def segment_wav(wav, segment_length=44100, hop_size=44100, min_segment_size=22050):
    """
    Splits a waveform into fixed-length segments, zero-padding the last one if needed.

    Args:
        wav: 1-D waveform tensor
        segment_length: number of samples in every returned segment
        hop_size: number of samples between the starts of consecutive segments
        min_segment_size: a segment is only started if more than this many samples remain
    Returns:
        List of 1-D tensors, each of length `segment_length`.
    """
    num_samples = len(wav)
    if num_samples < segment_length:
        # Short input: a single segment padded up to the full length.
        return [torch.cat([wav, torch.zeros(segment_length - num_samples)])]

    segments = []
    for start in range(0, num_samples - min_segment_size, hop_size):
        segment = wav[start : start + segment_length]
        if len(segment) < segment_length:
            segment = torch.cat([segment, torch.zeros(segment_length - len(segment))])
        segments.append(segment)
    return segments


def get_speaker_embedding(ssl_model, wav_featurizer, audio_paths, duration=None, device="cpu"):
    """
    Computes a single L2-normalized speaker embedding from one or more audio files.

    Every file is split into fixed-length segments, the speaker embeddings of all
    segments are averaged, and the mean is L2-normalized.

    Args:
        ssl_model: SSLDisentangler model
        wav_featurizer: WaveformFeaturizer object
        audio_paths: list of paths to audio files of the target speaker
        duration: If not None, the maximum number of segments to use
        device: device to run the model on
    Returns:
        speaker embedding with a leading batch dimension of size 1
    Raises:
        ValueError: if `audio_paths` is empty
    """
    if len(audio_paths) == 0:
        raise ValueError("audio_paths must contain at least one audio file")

    all_segments = []
    for audio_path in audio_paths:
        wav = load_wav(audio_path, wav_featurizer)
        all_segments += segment_wav(wav)
        if duration is not None and len(all_segments) >= duration:
            # `duration` is counted in segments. With the default segment_wav settings each segment
            # is 44100 samples long and consecutive segments do not overlap.
            all_segments = all_segments[: int(duration)]
            break

    signal_batch = torch.stack(all_segments).to(device)
    # Every segment has the same (padded) length.
    signal_length_batch = torch.full((len(all_segments),), signal_batch.shape[1], dtype=torch.long).to(device)

    _, speaker_embeddings, _, _, _ = ssl_model.forward_for_export(
        input_signal=signal_batch, input_signal_length=signal_length_batch, normalize_content=True
    )

    speaker_embedding = torch.mean(speaker_embeddings, dim=0)
    l2_norm = torch.norm(speaker_embedding, p=2)
    speaker_embedding = speaker_embedding / l2_norm

    return speaker_embedding[None]


def get_ssl_features_disentangled(
    ssl_model, wav_featurizer, audio_path, emb_type="embedding_and_probs", use_unique_tokens=False, device="cpu"
):
    """
    Extracts content embedding, speaker embedding and duration tokens to be used as inputs for FastPitchModel_SSL
    synthesizer. Content embedding and speaker embedding extracted using SSLDisentangler model.
    Args:
        ssl_model: SSLDisentangler model
        wav_featurizer: WaveformFeaturizer object
        audio_path: path to audio file
        emb_type: Can be one of embedding_and_probs, embedding, probs, log_probs
        use_unique_tokens: If True, content embeddings with same predicted token are grouped and duration is different.
        device: device to run the model on
    Returns:
        content_embedding, speaker_embedding, duration
    Raises:
        ValueError: if `emb_type` is not one of the supported values
    """
    wav = load_wav(audio_path, wav_featurizer)
    audio_signal = wav[None].to(device)
    audio_signal_length = torch.tensor([wav.shape[0]]).to(device)

    _, speaker_embedding, content_embedding, content_log_probs, encoded_len = ssl_model.forward_for_export(
        input_signal=audio_signal, input_signal_length=audio_signal_length, normalize_content=True
    )

    # Keep only the valid (non-padded) frames of the single batch item and put the
    # feature dimension first, so that columns index time.
    num_frames = encoded_len[0].item()
    content_embedding = content_embedding[0, :num_frames].t()
    content_log_probs = content_log_probs[:num_frames, 0, :].t()
    content_probs = torch.exp(content_log_probs)

    ssl_downsampling_factor = ssl_model._cfg.encoder.subsampling_factor

    if emb_type == "probs":
        # content embedding is only character probabilities
        final_content_embedding = content_probs
    elif emb_type == "embedding":
        # content embedding is only output of content head of SSL backbone
        final_content_embedding = content_embedding
    elif emb_type == "log_probs":
        # content embedding is only log of character probabilities
        final_content_embedding = content_log_probs
    elif emb_type == "embedding_and_probs":
        # content embedding is the concatenation of character probabilities and output of content head of SSL backbone
        final_content_embedding = torch.cat([content_embedding, content_probs], dim=0)
    else:
        raise ValueError(
            f"{emb_type} is not valid. Valid emb_type includes probs, embedding, log_probs or embedding_and_probs."
        )

    duration = torch.ones(final_content_embedding.shape[1]) * ssl_downsampling_factor
    if use_unique_tokens:
        # group content embeddings with same predicted token (by averaging) and add the durations of the grouped embeddings
        # Eg. By default each content embedding corresponds to 4 frames of spectrogram (ssl_downsampling_factor)
        # If we group 3 content embeddings, the duration of the grouped embedding will be 12 frames.
        # This is useful for adapting the duration during inference based on the speaker.
        # Move the token ids to a Python list once instead of comparing tensors element by element.
        token_predictions = torch.argmax(content_probs, dim=0).tolist()
        content_buffer = [final_content_embedding[:, 0]]
        unique_content_embeddings = []
        durations = []
        for _t in range(1, final_content_embedding.shape[1]):
            if token_predictions[_t] == token_predictions[_t - 1]:
                content_buffer.append(final_content_embedding[:, _t])
            else:
                durations.append(len(content_buffer) * ssl_downsampling_factor)
                unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
                content_buffer = [final_content_embedding[:, _t]]

        # Flush the last run. The buffer is never empty here; this also covers a single-frame
        # input, for which the loop above does not execute at all.
        durations.append(len(content_buffer) * ssl_downsampling_factor)
        unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))

        final_content_embedding = torch.stack(unique_content_embeddings).t()
        duration = torch.tensor(durations).float()

    duration = duration.to(device)

    return final_content_embedding[None], speaker_embedding, duration[None]


def _read_source_target_out_pairs(pairs_path):
    """
    Reads a file with one `source;target;out` triple per line. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line has fewer than three `;`-separated fields
    """
    pairs = []
    with open(pairs_path, "r") as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            fields = line.split(";")
            if len(fields) < 3:
                raise ValueError(
                    f"{pairs_path}:{line_number}: expected 'source;target;out' but got '{line}'"
                )
            pairs.append(fields)
    return pairs


def main():
    """
    Command line entry point: converts the voice of one or more source utterances to a target speaker.

    The source/target/output paths are given either through --source_audio_path, --target_audio_path
    and --out_path, or through --source_target_out_pairs, a file with one `source;target;out` triple
    per line. In both cases the target can be a comma separated list of audio files.
    """
    parser = argparse.ArgumentParser(description='Evaluate the model')
    parser.add_argument('--ssl_model_ckpt_path', type=str)
    parser.add_argument('--hifi_ckpt_path', type=str)
    parser.add_argument('--fastpitch_ckpt_path', type=str)
    parser.add_argument('--source_audio_path', type=str)
    parser.add_argument('--target_audio_path', type=str)  # can be a list separated by comma
    parser.add_argument('--out_path', type=str)
    parser.add_argument('--source_target_out_pairs', type=str)
    parser.add_argument('--use_unique_tokens', type=int, default=0)
    parser.add_argument('--compute_pitch', type=int, default=0)
    parser.add_argument('--compute_duration', type=int, default=0)
    args = parser.parse_args()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Argument validation uses parser.error rather than assert, so it also runs under `python -O`.
    if args.source_target_out_pairs is not None:
        if args.source_audio_path is not None:
            parser.error("source_audio_path and source_target_out_pairs are mutually exclusive")
        if args.target_audio_path is not None:
            parser.error("target_audio_path and source_target_out_pairs are mutually exclusive")
        if args.out_path is not None:
            parser.error("out_path and source_target_out_pairs are mutually exclusive")
        source_target_out_pairs = _read_source_target_out_pairs(args.source_target_out_pairs)
    else:
        if args.source_audio_path is None:
            parser.error("source_audio_path is required")
        if args.target_audio_path is None:
            parser.error("target_audio_path is required")
        if args.out_path is None:
            source_name = os.path.basename(args.source_audio_path).split(".")[0]
            target_name = os.path.basename(args.target_audio_path).split(".")[0]
            args.out_path = "swapped_{}_{}.wav".format(source_name, target_name)

        source_target_out_pairs = [(args.source_audio_path, args.target_audio_path, args.out_path)]

    out_paths = [pair[2] for pair in source_target_out_pairs]
    out_dir = get_base_dir(out_paths)
    os.makedirs(out_dir, exist_ok=True)

    ssl_model = ssl_tts.SSLDisentangler.load_from_checkpoint(args.ssl_model_ckpt_path, strict=False)
    ssl_model = ssl_model.to(device)
    ssl_model.eval()

    vocoder = hifigan.HifiGanModel.load_from_checkpoint(args.hifi_ckpt_path).to(device)
    vocoder.eval()

    fastpitch_model = fastpitch_ssl.FastPitchModel_SSL.load_from_checkpoint(args.fastpitch_ckpt_path, strict=False)
    fastpitch_model = fastpitch_model.to(device)
    fastpitch_model.eval()
    fastpitch_model.non_trainable_models = {'vocoder': vocoder}
    fpssl_sample_rate = fastpitch_model._cfg.sample_rate

    wav_featurizer = WaveformFeaturizer(sample_rate=fpssl_sample_rate, int_values=False, augmentor=None)

    use_unique_tokens = args.use_unique_tokens == 1
    compute_pitch = args.compute_pitch == 1
    compute_duration = args.compute_duration == 1

    for source_target_out in source_target_out_pairs:
        source_audio_path = source_target_out[0]
        target_audio_paths = source_target_out[1].split(",")
        out_path = source_target_out[2]

        with torch.no_grad():
            content_embedding1, _, duration1 = get_ssl_features_disentangled(
                ssl_model,
                wav_featurizer,
                source_audio_path,
                emb_type="embedding_and_probs",
                use_unique_tokens=use_unique_tokens,
                device=device,
            )

            speaker_embedding2 = get_speaker_embedding(
                ssl_model, wav_featurizer, target_audio_paths, duration=None, device=device
            )

            # When the pitch is not computed by the synthesizer, it is extracted from the source
            # audio and normalized with the source's own statistics.
            pitch_contour1 = None
            if not compute_pitch:
                pitch_contour1 = get_pitch_contour(
                    load_wav(source_audio_path, wav_featurizer), compute_mean_std=True, sample_rate=fpssl_sample_rate
                )[None]
                pitch_contour1 = pitch_contour1.to(device)

            wav_generated = fastpitch_model.generate_wav(
                content_embedding1,
                speaker_embedding2,
                pitch_contour=pitch_contour1,
                compute_pitch=compute_pitch,
                compute_duration=compute_duration,
                durs_gt=duration1,
                dataset_id=0,
            )
            wav_generated = wav_generated[0][0]
            soundfile.write(out_path, wav_generated, fpssl_sample_rate)


if __name__ == "__main__":
    main()