File size: 773 Bytes
127d53c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import numpy as np
import torch
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator
def test_pwgan_generator():
    """Check ParallelWaveganGenerator output shapes in forward() and inference().

    The generator is conditioned on an 80-channel auxiliary input (presumably
    mel-spectrogram frames) and upsamples each conditioning frame by the
    product of ``upsample_factors``. ``forward`` must return
    ``num_frames * hop_length`` samples per item. After weight norm is removed,
    ``inference`` is expected to return 4 extra frames' worth of samples.
    """
    upsample_factors = [4, 4, 4, 4]
    # Samples generated per conditioning frame: product of the upsample factors (256 here).
    hop_length = int(np.prod(upsample_factors))
    batch_size, aux_channels, num_frames = 2, 80, 5

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=aux_channels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors,
    )
    dummy_c = torch.rand((batch_size, aux_channels, num_frames))

    output = model(dummy_c)
    # torch.Size is a tuple subclass, so a direct comparison yields a bool;
    # wrapping it in np.all() adds nothing.
    expected = (batch_size, 1, num_frames * hop_length)
    assert output.shape == expected, output.shape

    model.remove_weight_norm()
    output = model.inference(dummy_c)
    # NOTE(review): the 4 extra frames assume inference() pads the conditioning
    # input by 2 frames on each side — confirm against the generator's
    # inference padding setting.
    expected = (batch_size, 1, (num_frames + 4) * hop_length)
    assert output.shape == expected, output.shape
|