|
import os |
|
import random |
|
import unittest |
|
from copy import deepcopy |
|
|
|
import torch |
|
|
|
from tests import get_tests_output_path |
|
from TTS.tts.configs.overflow_config import OverflowConfig |
|
from TTS.tts.layers.overflow.common_layers import Encoder, Outputnet, OverflowUtils |
|
from TTS.tts.layers.overflow.decoder import Decoder |
|
from TTS.tts.layers.overflow.neural_hmm import EmissionModel, NeuralHMM, TransitionModel |
|
from TTS.tts.models.overflow import Overflow |
|
from TTS.tts.utils.helpers import sequence_mask |
|
from TTS.utils.audio import AudioProcessor |
|
|
|
|
|
|
|
# Fixed seed so tensor-based fixtures are reproducible across runs.
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Shared configuration and audio processor; tests deep-copy the config before changing it.
config_global = OverflowConfig(num_chars=24)
ap = AudioProcessor.init_from_config(config_global)

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")

# Fixed mel mean/std and initial transition probability, written to disk so that
# configs can point `mel_statistics_parameter_path` at this file (see get_model).
_lj_statistics = {"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}
torch.save(_lj_statistics, parameter_path)
|
|
|
|
|
def _create_inputs(batch_size=8):
    """Create a random, zero-padded batch of token ids and mel frames on ``device``.

    The first item of the batch always has the maximum text and mel lengths, and
    both length vectors are sorted in descending order.

    Args:
        batch_size: Number of items in the batch.

    Returns:
        Tuple ``(input_dummy, input_lengths, mel_spec, mel_lengths)`` where
        ``input_dummy`` is ``(batch_size, T_text)`` long ids in ``[0, 24)``,
        ``mel_spec`` is ``(batch_size, T_mel, num_mels)`` and both tensors are
        zeroed beyond each item's length.
    """
    text_len = random.randint(25, 50)
    mel_len = random.randint(50, 80)

    tokens = torch.randint(0, 24, (batch_size, text_len)).long().to(device)
    token_lengths, _ = torch.randint(20, text_len, (batch_size,)).long().to(device).sort(descending=True)
    token_lengths[0] = text_len  # guarantee at least one unpadded sequence
    tokens = tokens * sequence_mask(token_lengths)

    n_mels = config_global.audio["num_mels"]
    frames = torch.randn(batch_size, mel_len, n_mels).to(device)
    frame_lengths, _ = torch.randint(40, mel_len, (batch_size,)).long().to(device).sort(descending=True)
    frame_lengths[0] = mel_len
    frames = frames * sequence_mask(frame_lengths)[:, :, None]

    return tokens, token_lengths, frames, frame_lengths
|
|
|
|
|
def get_model(config=None):
    """Build an Overflow model on ``device`` for testing.

    The config is deep-copied before ``mel_statistics_parameter_path`` is pointed at
    the test statistics file. This way neither the module-level ``config_global`` nor
    a caller-supplied config is mutated, and state does not leak between tests.

    Args:
        config: Optional OverflowConfig; defaults to ``config_global``.

    Returns:
        An ``Overflow`` model moved to ``device``.
    """
    config = deepcopy(config_global if config is None else config)
    config.mel_statistics_parameter_path = parameter_path
    return Overflow(config).to(device)
|
|
|
|
|
def reset_all_weights(model):
    """Re-initialise, in place, every submodule of ``model`` that defines ``reset_parameters``.

    Submodules without a callable ``reset_parameters`` are left untouched. The
    re-initialisation runs under ``torch.no_grad()`` so it is not recorded by autograd.

    refs:
    - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
    - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
    - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def _reset_if_supported(module):
        reinit = getattr(module, "reset_parameters", None)
        if not callable(reinit):
            return
        reinit()

    # Module.apply visits every submodule recursively, then the module itself.
    model.apply(fn=_reset_if_supported)
|
|
|
|
|
class TestOverflow(unittest.TestCase):
    """End-to-end tests of the full Overflow model."""

    def test_forward(self):
        """Training forward pass yields one log-prob per item and state_per_phone * T_text states."""
        model = get_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        outputs = model(input_dummy, input_lengths, mel_spec, mel_lengths)
        self.assertEqual(outputs["log_probs"].shape, (input_dummy.shape[0],))
        self.assertEqual(model.state_per_phone * max(input_lengths), outputs["alignments"].shape[2])

    def test_inference(self):
        """Inference output's last dimension equals the configured out_channels."""
        model = get_model()
        input_dummy, _, _, _ = _create_inputs()
        output_dict = model.inference(input_dummy)
        self.assertEqual(output_dict["model_outputs"].shape[2], config_global.out_channels)

    def test_init_from_config(self):
        """Values from the config passed to ``init_from_config`` reach the model."""
        config = deepcopy(config_global)
        config.mel_statistics_parameter_path = parameter_path
        # Use a value that differs from config_global's so the check is not vacuous.
        custom_prenet_dim = 128
        config.prenet_dim = custom_prenet_dim
        # Build from the modified copy (not config_global) so the override is actually tested.
        model = Overflow.init_from_config(config)
        self.assertEqual(model.prenet_dim, custom_prenet_dim)
|
|
|
|
|
class TestOverflowEncoder(unittest.TestCase):
    """Checks that the encoder repeats each phone ``state_per_phone`` times along time."""

    @staticmethod
    def get_encoder(state_per_phone):
        """Build an Encoder on ``device`` from a copy of ``config_global`` with the given states per phone."""
        cfg = deepcopy(config_global)
        cfg.state_per_phone = state_per_phone
        cfg.num_chars = 24
        encoder = Encoder(cfg.num_chars, cfg.state_per_phone, cfg.prenet_dim, cfg.encoder_n_convolutions)
        return encoder.to(device)

    def _check_state_expansion(self, select_call):
        """Run ``select_call(encoder)`` for 1-3 states per phone and check the expanded time axis."""
        for states in [1, 2, 3]:
            tokens, token_lengths, _, _ = _create_inputs()
            encoder = self.get_encoder(states)
            encoded, _ = select_call(encoder)(tokens, token_lengths)
            self.assertEqual(encoded.shape[1], tokens.shape[1] * states)

    def test_forward_with_state_per_phone_multiplication(self):
        self._check_state_expansion(lambda enc: enc)

    def test_inference_with_state_per_phone_multiplication(self):
        self._check_state_expansion(lambda enc: enc.inference)
|
|
|
|
|
class TestOverflowUtils(unittest.TestCase):
    """Tests for the numeric helpers in OverflowUtils."""

    def test_logsumexp(self):
        """``OverflowUtils.logsumexp`` agrees with ``torch.logsumexp``.

        The two are independent floating-point implementations, so results are compared
        with ``torch.isclose`` instead of exact equality. Bit-identical outputs are not
        guaranteed for random inputs.
        """
        for a in (torch.randn(10), torch.zeros(10), torch.ones(10)):
            with self.subTest(a=a):
                expected = torch.logsumexp(a, dim=0)
                actual = OverflowUtils.logsumexp(a, dim=0)
                self.assertTrue(torch.isclose(expected, actual).all())
|
|
|
|
|
class TestOverflowDecoder(unittest.TestCase):
    """Invertibility tests for the flow-based Overflow decoder."""

    @staticmethod
    def _get_decoder(num_flow_blocks_dec=None, hidden_channels_dec=None, reset_weights=True):
        """Build a Decoder on ``device`` from a copy of ``config_global``.

        Args:
            num_flow_blocks_dec: Overrides ``config.num_flow_blocks_dec`` when not None.
            hidden_channels_dec: Overrides ``config.hidden_channels_dec`` when not None.
            reset_weights: If True, re-initialise all submodules via ``reset_all_weights``.
                NOTE(review): presumably so that any special (e.g. zero) initialisation
                does not make the flow trivially invertible — confirm.
        """
        config = deepcopy(config_global)
        if num_flow_blocks_dec is not None:
            config.num_flow_blocks_dec = num_flow_blocks_dec
        if hidden_channels_dec is not None:
            config.hidden_channels_dec = hidden_channels_dec
        # Dropout would make the forward and reverse passes non-deterministic.
        config.dropout_p_dec = 0.0
        decoder = Decoder(
            config.out_channels,
            config.hidden_channels_dec,
            config.kernel_size_dec,
            config.dilation_rate,
            config.num_flow_blocks_dec,
            config.num_block_layers,
            config.dropout_p_dec,
            config.num_splits,
            config.num_squeeze,
            config.sigmoid_scale,
            config.c_in_channels,
        ).to(device)
        if reset_weights:
            reset_all_weights(decoder)
        return decoder

    def test_decoder_forward_backward(self):
        """A forward pass followed by the reverse pass reconstructs the (masked) input mels."""
        for num_flow_blocks_dec in [8, None]:
            for hidden_channels_dec in [100, None]:
                with self.subTest(num_flow_blocks_dec=num_flow_blocks_dec, hidden_channels_dec=hidden_channels_dec):
                    decoder = self._get_decoder(num_flow_blocks_dec, hidden_channels_dec)
                    _, _, mel_spec, mel_lengths = _create_inputs()
                    z, z_len, _ = decoder(mel_spec.transpose(1, 2), mel_lengths)
                    mel_spec_, _, _ = decoder(z, z_len, reverse=True)
                    mask = sequence_mask(z_len).unsqueeze(1)
                    # The decoder may shorten the time axis; compare only the frames it kept.
                    mel_spec = mel_spec[:, : z.shape[2], :].transpose(1, 2) * mask
                    self.assertTrue(
                        torch.isclose(mel_spec, mel_spec_, atol=1e-2).all(),
                        f"num_flow_blocks_dec={num_flow_blocks_dec}, hidden_channels_dec={hidden_channels_dec}",
                    )
|
|
|
|
|
class TestNeuralHMM(unittest.TestCase):
    """Unit tests for NeuralHMM and its EmissionModel / TransitionModel components."""

    @staticmethod
    def _get_neural_hmm(deterministic_transition=None):
        """Build a NeuralHMM on ``device`` from a copy of ``config_global``.

        Args:
            deterministic_transition: Overrides ``config.deterministic_transition`` when not None.
        """
        config = deepcopy(config_global)
        neural_hmm = NeuralHMM(
            config.out_channels,
            config.ar_order,
            config.deterministic_transition if deterministic_transition is None else deterministic_transition,
            config.encoder_in_out_features,
            config.prenet_type,
            config.prenet_dim,
            config.prenet_n_layers,
            config.prenet_dropout,
            config.prenet_dropout_at_inference,
            config.memory_rnn_dim,
            config.outputnet_size,
            config.flat_start_params,
            config.std_floor,
        ).to(device)
        return neural_hmm

    @staticmethod
    def _get_emission_model():
        return EmissionModel().to(device)

    @staticmethod
    def _get_transition_model():
        return TransitionModel().to(device)

    @staticmethod
    def _get_embedded_input():
        """Return ``_create_inputs()`` with token ids replaced by embeddings.

        A freshly initialised ``nn.Embedding`` maps the ids to
        ``(batch, T_text, encoder_in_out_features)`` features.
        """
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)(
            input_dummy
        )
        return input_dummy, input_lengths, mel_spec, mel_lengths

    def test_neural_hmm_forward(self):
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        self.assertEqual(log_prob.shape, (input_dummy.shape[0],))
        self.assertEqual(log_alpha_scaled.shape, transition_matrix.shape)

    def test_mask_lengths(self):
        """Entries past each item's mel length are zeroed in log_c and log_alpha_scaled."""
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        log_c = torch.randn(mel_spec.shape[0], mel_spec.shape[1], device=device)
        log_c, log_alpha_scaled = neural_hmm._mask_lengths(mel_lengths, log_c, log_alpha_scaled)
        assertions = []
        for i in range(mel_spec.shape[0]):
            assertions.append(log_c[i, mel_lengths[i] :].sum() == 0.0)
        self.assertTrue(all(assertions), "Incorrect masking")
        assertions = []
        for i in range(mel_spec.shape[0]):
            assertions.append(log_alpha_scaled[i, mel_lengths[i] :, : input_lengths[i]].sum() == 0.0)
        self.assertTrue(all(assertions), "Incorrect masking")

    def test_process_ar_timestep(self):
        """One autoregressive step keeps the memory RNN states at (batch, memory_rnn_dim)."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

        h_post_prenet, c_post_prenet = model._init_lstm_states(
            input_dummy.shape[0], config_global.memory_rnn_dim, mel_spec
        )
        h_post_prenet, c_post_prenet = model._process_ar_timestep(
            1,
            mel_spec,
            h_post_prenet,
            c_post_prenet,
        )

        self.assertEqual(h_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))
        self.assertEqual(c_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))

    def test_add_go_token(self):
        """_add_go_token keeps the shape and shifts frames right by one along time."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

        out = model._add_go_token(mel_spec)
        self.assertEqual(out.shape, mel_spec.shape)
        self.assertTrue((out[:, 1:] == mel_spec[:, :-1]).all(), "Go token not appended properly")

    def test_forward_algorithm_variables(self):
        """Forward-algorithm buffers are allocated with (batch, T_mel[, N_states]) shapes."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )

        self.assertEqual(log_c.shape, (mel_spec.shape[0], mel_spec.shape[1]))
        self.assertEqual(
            log_alpha_scaled.shape,
            (
                mel_spec.shape[0],
                mel_spec.shape[1],
                input_dummy.shape[1] * config_global.state_per_phone,
            ),
        )
        self.assertEqual(
            transition_matrix.shape,
            (mel_spec.shape[0], mel_spec.shape[1], input_dummy.shape[1] * config_global.state_per_phone),
        )

    def test_get_absorption_state_scaling_factor(self):
        """Compare get_absorption_state_scaling_factor against a per-item reference computation."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        input_lengths = input_lengths * config_global.state_per_phone
        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )
        log_alpha_scaled = torch.rand_like(log_alpha_scaled).clamp(1e-3)
        transition_matrix = torch.randn_like(transition_matrix).sigmoid().log()
        sum_final_log_c = model.get_absorption_state_scaling_factor(
            mel_lengths, log_alpha_scaled, input_lengths, transition_matrix
        )

        text_mask = ~sequence_mask(input_lengths)
        transition_prob_mask = ~model.get_mask_for_last_item(input_lengths, device=input_lengths.device)

        outputs = []

        for i in range(input_dummy.shape[0]):
            last_log_alpha_scaled = log_alpha_scaled[i, mel_lengths[i] - 1].masked_fill(text_mask[i], -float("inf"))
            log_last_transition_probability = OverflowUtils.log_clamped(
                torch.sigmoid(transition_matrix[i, mel_lengths[i] - 1])
            ).masked_fill(transition_prob_mask[i], -float("inf"))
            outputs.append(last_log_alpha_scaled + log_last_transition_probability)

        sum_final_log_c_computed = torch.logsumexp(torch.stack(outputs), dim=1)

        self.assertTrue(torch.isclose(sum_final_log_c_computed, sum_final_log_c).all())

    def test_inference(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        for temp in [0.334, 0.667, 1.0]:
            outputs = model.inference(
                input_dummy, input_lengths, temp, config_global.max_sampling_time, config_global.duration_threshold
            )
            self.assertEqual(outputs["hmm_outputs"].shape[-1], outputs["input_parameters"][0][0][0].shape[-1])
            self.assertEqual(
                outputs["output_parameters"][0][0][0].shape[-1], outputs["input_parameters"][0][0][0].shape[-1]
            )
            self.assertEqual(len(outputs["alignments"]), input_dummy.shape[0])

    def test_emission_model(self):
        """EmissionModel scores have shape (batch, T_text); sampling at temperature 0 returns the means."""
        model = self._get_emission_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        x_t = torch.randn(input_dummy.shape[0], config_global.out_channels).to(device)
        means = torch.randn(input_dummy.shape[0], input_dummy.shape[1], config_global.out_channels).to(device)
        std = torch.rand_like(means).to(device).clamp_(1e-3)
        out = model(x_t, means, std, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_dummy.shape[1]))

        for temp in [0, 0.334, 0.667]:
            # Pass the loop temperature so both the deterministic (temp == 0) and the
            # stochastic (temp > 0) sampling paths are exercised.
            out = model.sample(means, std, temp)
            self.assertEqual(out.shape, means.shape)
            if temp == 0:
                self.assertTrue(torch.isclose(out, means).all())

    def test_transition_model(self):
        model = self._get_transition_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        prev_t_log_scaled_alph = torch.randn(input_dummy.shape[0], input_lengths.max()).to(device)
        # NOTE(review): a 1-D transition vector is broadcast across the batch here — confirm intended.
        transition_vector = torch.randn(input_lengths.max()).to(device)
        out = model(prev_t_log_scaled_alph, transition_vector, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_lengths.max()))
|
|
|
|
|
class TestOverflowOutputNet(unittest.TestCase):
    """Checks that a freshly built Outputnet emits its flat-start parameters."""

    @staticmethod
    def _get_outputnet():
        """Build an Outputnet on ``device`` from a copy of ``config_global``."""
        cfg = deepcopy(config_global)
        net = Outputnet(
            cfg.encoder_in_out_features,
            cfg.memory_rnn_dim,
            cfg.out_channels,
            cfg.outputnet_size,
            cfg.flat_start_params,
            cfg.std_floor,
        )
        return net.to(device)

    @staticmethod
    def _get_embedded_input():
        """Return embedded random tokens and one random memory-RNN-sized frame per batch item."""
        tokens, *_ = _create_inputs()
        embedding = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)
        embedded = embedding(tokens)
        memory_frame = torch.randn(embedded.shape[0], config_global.memory_rnn_dim).to(device)
        return embedded, memory_frame

    def test_outputnet_forward_with_flat_start(self):
        model = self._get_outputnet()
        input_dummy, one_timestep_frame = self._get_embedded_input()
        mean, std, transition_vector = model(one_timestep_frame, input_dummy)
        flat = model.flat_start_params
        expectations = (
            (mean, flat["mean"]),
            (std, flat["std"]),
            (transition_vector.sigmoid(), flat["transition_p"]),
        )
        for produced, target in expectations:
            self.assertTrue(torch.isclose(produced, torch.tensor(target * 1.0)).all())
|
|