|
import numpy as np |
|
import torch |
|
|
|
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator |
|
|
|
|
|
def test_pwgan_generator():
    """Check ParallelWaveganGenerator output shapes for forward and inference.

    The forward pass is expected to upsample each conditioning frame into
    ``prod(upsample_factors)`` waveform samples. After ``remove_weight_norm()``,
    ``inference`` is expected to return 4 extra frames' worth of samples
    (presumably because inference pads the conditioning input on both sides --
    TODO confirm against ParallelWaveganGenerator.inference).
    """
    upsample_factors = [4, 4, 4, 4]
    # Number of waveform samples generated per conditioning frame (4**4 = 256).
    hop_length = int(np.prod(upsample_factors))
    aux_channels = 80
    batch_size, num_frames = 2, 5

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=aux_channels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors,
    )
    dummy_c = torch.rand((batch_size, aux_channels, num_frames))

    # Training-mode forward pass: output length is num_frames * hop_length.
    output = model(dummy_c)
    expected = (batch_size, 1, num_frames * hop_length)
    assert tuple(output.shape) == expected, output.shape

    # Inference path (weight norm removed): 4 additional frames of output.
    model.remove_weight_norm()
    output = model.inference(dummy_c)
    expected = (batch_size, 1, (num_frames + 4) * hop_length)
    assert tuple(output.shape) == expected, output.shape
|
|