Upload folder using huggingface_hub
- .gitattributes +2 -0
- RingFormer/LICENSE +21 -0
- RingFormer/README.md +38 -0
- RingFormer/Utils/JDC/__init__.py +1 -0
- RingFormer/Utils/JDC/__pycache__/__init__.cpython-311.pyc +0 -0
- RingFormer/Utils/JDC/__pycache__/__init__.cpython-38.pyc +0 -0
- RingFormer/Utils/JDC/__pycache__/__init__.cpython-39.pyc +0 -0
- RingFormer/Utils/JDC/__pycache__/bst.t7 +3 -0
- RingFormer/Utils/JDC/__pycache__/model.cpython-311.pyc +0 -0
- RingFormer/Utils/JDC/__pycache__/model.cpython-38.pyc +0 -0
- RingFormer/Utils/JDC/__pycache__/model.cpython-39.pyc +0 -0
- RingFormer/Utils/JDC/bst.t7 +3 -0
- RingFormer/Utils/JDC/model.py +192 -0
- RingFormer/Utils/__init__.py +1 -0
- RingFormer/Utils/__pycache__/__init__.cpython-311.pyc +0 -0
- RingFormer/Utils/__pycache__/__init__.cpython-38.pyc +0 -0
- RingFormer/Utils/__pycache__/__init__.cpython-39.pyc +0 -0
- RingFormer/__pycache__/conformer.cpython-311.pyc +0 -0
- RingFormer/__pycache__/env.cpython-311.pyc +0 -0
- RingFormer/__pycache__/meldataset.cpython-311.pyc +0 -0
- RingFormer/__pycache__/models.cpython-311.pyc +0 -0
- RingFormer/__pycache__/norm2d.cpython-311.pyc +0 -0
- RingFormer/__pycache__/stft.cpython-311.pyc +0 -0
- RingFormer/__pycache__/utils.cpython-311.pyc +0 -0
- RingFormer/config_v1.json +42 -0
- RingFormer/conformer.py +228 -0
- RingFormer/env.py +15 -0
- RingFormer/inference.ipynb +292 -0
- RingFormer/meldataset.py +203 -0
- RingFormer/models.py +943 -0
- RingFormer/norm2d.py +92 -0
- RingFormer/requirements.txt +10 -0
- RingFormer/stft.py +254 -0
- RingFormer/train.py +671 -0
- RingFormer/utils.py +58 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+RingFormer/Utils/JDC/__pycache__/bst.t7 filter=lfs diff=lfs merge=lfs -text
+RingFormer/Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
RingFormer/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Aaron (Yinghao) Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
RingFormer/README.md
ADDED
@@ -0,0 +1,38 @@
# HiFTNet: A Fast High-Quality Neural Vocoder with Harmonic-plus-Noise Filter and Inverse Short Time Fourier Transform

### Yinghao Aaron Li, Cong Han, Xilin Jiang, Nima Mesgarani

> Recent advancements in speech synthesis have leveraged GAN-based networks like HiFi-GAN and BigVGAN to produce high-fidelity waveforms from mel-spectrograms. However, these networks are computationally expensive and parameter-heavy. iSTFTNet addresses these limitations by integrating the inverse short-time Fourier transform (iSTFT) into the network, achieving both speed and parameter efficiency. In this paper, we introduce an extension to iSTFTNet, termed HiFTNet, which incorporates a harmonic-plus-noise source filter in the time-frequency domain that uses a sinusoidal source from the fundamental frequency (F0) inferred via a pre-trained F0 estimation network for fast inference speed. Subjective evaluations on LJSpeech show that our model significantly outperforms both iSTFTNet and HiFi-GAN, achieving ground-truth-level performance. HiFTNet also outperforms BigVGAN-base on LibriTTS for unseen speakers and achieves comparable performance to BigVGAN while being four times faster with only 1/6 of the parameters. Our work sets a new benchmark for efficient, high-quality neural vocoding, paving the way for real-time applications that demand high-quality speech synthesis.

Paper: [https://arxiv.org/abs/2309.09493](https://arxiv.org/abs/2309.09493)

Audio samples: [https://hiftnet.github.io/](https://hiftnet.github.io/)

**Check out our TTS work that uses HiFTNet as the speech decoder for human-level speech synthesis: https://github.com/yl4579/StyleTTS2**

## Pre-requisites
1. Python >= 3.7
2. Clone this repository:
```bash
git clone https://github.com/yl4579/HiFTNet.git
cd HiFTNet
```
3. Install python requirements:
```bash
pip install -r requirements.txt
```

## Training
```bash
python train.py --config config_v1.json --[args]
```
For F0 model training, please refer to [yl4579/PitchExtractor](https://github.com/yl4579/PitchExtractor). This repo includes an F0 model pre-trained on LibriTTS. Still, you may want to train your own F0 model for the best performance, particularly on noisy or non-speech data, as we found that F0 estimation accuracy is essential to vocoder performance.

## Inference
Please refer to the notebook [inference.ipynb](https://github.com/yl4579/HiFTNet/blob/main/inference.ipynb) for details.

### Pre-Trained Models
You can download the pre-trained LJSpeech model [here](https://huggingface.co/yl4579/HiFTNet/blob/main/LJSpeech/cp_hifigan.zip) and the pre-trained LibriTTS model [here](https://huggingface.co/yl4579/HiFTNet/blob/main/LibriTTS/cp_hifigan.zip). The pre-trained models contain the optimizer and discriminator parameters, which can be used for fine-tuning.

## References
- [rishikksh20/iSTFTNet-pytorch](https://github.com/rishikksh20/iSTFTNet-pytorch)
- [nii-yamagishilab/project-NN-Pytorch-scripts/project/01-nsf](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf)
RingFormer/Utils/JDC/__init__.py
ADDED
@@ -0,0 +1 @@
RingFormer/Utils/JDC/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (154 Bytes).
RingFormer/Utils/JDC/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (129 Bytes).
RingFormer/Utils/JDC/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (129 Bytes).
RingFormer/Utils/JDC/__pycache__/bst.t7
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a7939fda04b8ef20365a6e6132e8190addd1ffdfb4c3f7a6892d0d7d6fe70d8
size 63049083
RingFormer/Utils/JDC/__pycache__/model.cpython-311.pyc
ADDED
Binary file (10.3 kB).
RingFormer/Utils/JDC/__pycache__/model.cpython-38.pyc
ADDED
Binary file (4.73 kB).
RingFormer/Utils/JDC/__pycache__/model.cpython-39.pyc
ADDED
Binary file (4.72 kB).
RingFormer/Utils/JDC/bst.t7
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
size 21029926
RingFormer/Utils/JDC/model.py
ADDED
@@ -0,0 +1,192 @@
"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import torch
from torch import nn

class JDCNet(nn.Module):
    """
    Joint Detection and Classification Network model for singing voice melody.
    """

    def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
        super().__init__()
        self.num_class = num_class

        # input = (b, 1, 31, 513), b = batch size
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False),  # out: (b, 64, 31, 513)
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
        )

        # res blocks
        self.res_block1 = ResBlock(in_channels=64, out_channels=128)  # (b, 128, 31, 128)
        self.res_block2 = ResBlock(in_channels=128, out_channels=192)  # (b, 192, 31, 32)
        self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)

        # pool block
        self.pool_block = nn.Sequential(
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
            nn.Dropout(p=0.2),
        )

        # maxpool layers (for auxiliary network inputs)
        # in = (b, 64, 31, 128) from conv_block, out = (b, 64, 31, 4)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 32))  # 128 / 32 = 4
        # in = (b, 128, 31, 64) from res_block1, out = (b, 128, 31, 4)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 16))  # 64 / 16 = 4
        # in = (b, 192, 31, 32) from res_block2, out = (b, 192, 31, 4)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 8))  # 32 / 8 = 4

        # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
        self.detector_conv = nn.Sequential(
            nn.Conv2d(640, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Dropout(p=0.2),
        )

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_classifier = nn.LSTM(
            input_size=1024, hidden_size=256,
            batch_first=True, bidirectional=True)  # (b, 31, 512)

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_detector = nn.LSTM(
            input_size=1024, hidden_size=256,
            batch_first=True, bidirectional=True)  # (b, 31, 512)

        # input: (b * 31, 512)
        self.classifier = nn.Linear(in_features=512, out_features=self.num_class)  # (b * 31, num_class)

        # input: (b * 31, 512)
        self.detector = nn.Linear(in_features=512, out_features=2)  # (b * 31, 2) - binary classifier

        # initialize weights
        self.apply(self.init_weights)

    def get_feature_GAN(self, x):
        seq_len = x.shape[-2]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return poolblock_out.transpose(-1, -2)

    def get_feature(self, x):
        seq_len = x.shape[-2]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return self.pool_block[2](poolblock_out)

    def forward(self, x):
        """
        Returns:
            classification_prediction, detection_prediction
            sizes: (b, 31, 722), (b, 31, 2)
        """
        ###############################
        # forward pass for classifier #
        ###############################
        seq_len = x.shape[-1]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)

        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)
        GAN_feature = poolblock_out.transpose(-1, -2)
        poolblock_out = self.pool_block[2](poolblock_out)

        # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
        classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 1024))
        classifier_out, _ = self.bilstm_classifier(classifier_out)  # ignore the hidden states

        classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
        classifier_out = self.classifier(classifier_out)
        classifier_out = classifier_out.view((-1, seq_len, self.num_class))  # (b, 31, num_class)

        # sizes: (b, 31, 722), (b, 31, 2)
        # classifier output consists of predicted pitch classes per frame
        # detector output consists of: (isvoice, notvoice) estimates per frame
        return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
            for p in m.parameters():
                if p.data is None:
                    continue

                if len(p.shape) >= 2:
                    nn.init.orthogonal_(p.data)
                else:
                    nn.init.normal_(p.data)


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
        super().__init__()
        self.downsample = in_channels != out_channels

        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
        )

        # conv layers
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )

        # 1 x 1 convolution layer to match the feature dimensions
        self.conv1by1 = None
        if self.downsample:
            self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        x = self.pre_conv(x)
        if self.downsample:
            x = self.conv(x) + self.conv1by1(x)
        else:
            x = self.conv(x) + x
        return x
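For orientation, a minimal smoke test of `JDCNet` as this repo instantiates it (`num_class=1, seq_len=192`, per `inference.ipynb`). The `(batch, 1, num_mels, frames)` input layout is an assumption read off the transposes and reshapes in `forward`, with `num_mels=128` taken from `config_v1.json`:

```python
import torch
from Utils.JDC.model import JDCNet

# num_class=1 mirrors the instantiation in inference.ipynb; the input layout
# (batch, 1, num_mels, frames) is inferred from forward()'s reshapes, not documented.
f0_model = JDCNet(num_class=1, seq_len=192)
mel = torch.randn(2, 128, 192)                       # (batch, num_mels=128, frames)
f0, gan_feature, features = f0_model(mel.unsqueeze(1))
print(f0.shape)  # torch.Size([2, 192]): one (absolute-valued) F0 estimate per frame
```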
RingFormer/Utils/__init__.py
ADDED
@@ -0,0 +1 @@
RingFormer/Utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (150 Bytes).
RingFormer/Utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (125 Bytes).
RingFormer/Utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (125 Bytes).
RingFormer/__pycache__/conformer.cpython-311.pyc
ADDED
Binary file (11.2 kB).
RingFormer/__pycache__/env.cpython-311.pyc
ADDED
Binary file (1.32 kB).
RingFormer/__pycache__/meldataset.cpython-311.pyc
ADDED
Binary file (12.9 kB).
RingFormer/__pycache__/models.cpython-311.pyc
ADDED
Binary file (49.7 kB).
RingFormer/__pycache__/norm2d.cpython-311.pyc
ADDED
Binary file (4.6 kB).
RingFormer/__pycache__/stft.cpython-311.pyc
ADDED
Binary file (13.4 kB).
RingFormer/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (3.46 kB).
RingFormer/config_v1.json
ADDED
@@ -0,0 +1,42 @@
{
    "F0_path": "/home/ubuntu/Darya_Speech/Utils/JDC/epoch_00079.pth",

    "resblock": "1",
    "num_gpus": 1,
    "batch_size": 18,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,


    "upsample_rates": [16,8],
    "upsample_kernel_sizes": [32, 16],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "gen_istft_n_fft": 32,
    "gen_istft_hop_size": 4,


    "segment_size": 65536,
    "num_mels": 128,
    "n_fft": 2048,
    "hop_size": 512,
    "win_size": 2048,

    "sampling_rate": 44100,

    "fmin": 0,
    "fmax": null,
    "fmax_for_loss": null,

    "num_workers": 8,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    }
}
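One way to sanity-check these numbers: the product of `upsample_rates` times `gen_istft_hop_size` must equal `hop_size` (16 * 8 * 4 = 512), so the generator emits exactly one hop of waveform samples per mel frame. A minimal sketch, assuming it is run from the `RingFormer` directory so that `env.AttrDict` is importable:

```python
import json
from env import AttrDict  # defined in RingFormer/env.py

with open("config_v1.json") as f:
    h = AttrDict(json.load(f))

# total upsampling = prod(upsample_rates) * gen_istft_hop_size = 16 * 8 * 4 = 512
total_upsampling = h.gen_istft_hop_size
for rate in h.upsample_rates:
    total_upsampling *= rate
assert total_upsampling == h.hop_size, (total_upsampling, h.hop_size)
```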
RingFormer/conformer.py
ADDED
@@ -0,0 +1,228 @@
import torch
from torch import nn
import torch.nn.functional as F

from einops.layers.torch import Rearrange
from ring_attention_pytorch import RingAttention

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)


# helper classes


class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)


# attention, feedforward, and conv module


class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class ConformerConvModule(nn.Module):
    def __init__(
        self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange("b n c -> b c n"),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
            ),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange("b c n -> b n c"),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


# Conformer Block


class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.attn = RingAttention(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            causal=True,
            auto_shard_seq=False,  # doesn't work on multi-gpu setup for some reason
            ring_attn=True,
            ring_seq_size=512,
        )
        self.self_attn_dropout = torch.nn.Dropout(attn_dropout)
        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        x_ff1 = self.ff1(x) + x

        x = self.attn(x, mask=mask)
        x = self.self_attn_dropout(x)
        x = x + x_ff1
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x


# Conformer


class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False
    ):
        super().__init__()
        self.dim = dim

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                ConformerBlock(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    conv_kernel_size=conv_kernel_size,
                    conv_causal=conv_causal,
                )
            )

    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return x
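A minimal shape-preserving usage sketch of `Conformer`; `dim` and `depth` here are illustrative rather than values the trainer uses, and running `RingAttention` in a single, non-distributed process is assumed to work (the repo's own inference notebook runs the model on one GPU):

```python
import torch
from conformer import Conformer

model = Conformer(dim=256, depth=2, dim_head=64, heads=4)
x = torch.randn(1, 512, 256)   # (batch, sequence, dim); ring_seq_size is fixed at 512 above
y = model(x)
print(y.shape)                 # torch.Size([1, 512, 256]): each block preserves the shape
```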
RingFormer/env.py
ADDED
@@ -0,0 +1,15 @@
import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
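A note on the `AttrDict` idiom: binding `self.__dict__` to the dict itself makes attribute access and item access two views of the same storage, which is why the loaded config can be read as `h.hop_size` throughout the repo. A minimal illustration:

```python
from env import AttrDict

h = AttrDict({"num_mels": 128})
h.hop_size = 512        # an attribute write lands in the dict...
print(h["hop_size"])    # 512
print(h.num_mels)       # ...and existing keys read back as attributes: 128
```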
RingFormer/inference.ipynb
ADDED
@@ -0,0 +1,292 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7b82eb58",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/respair/lib/python3.11/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/RINGFORMER\n"
     ]
    }
   ],
   "source": [
    "%cd /home/ubuntu/RINGFORMER\n",
    "\n",
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "import glob\n",
    "import os\n",
    "import argparse\n",
    "import json\n",
    "import torch\n",
    "from scipy.io.wavfile import write\n",
    "from env import AttrDict\n",
    "from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav\n",
    "from models import Generator\n",
    "from stft import TorchSTFT\n",
    "\n",
    "from Utils.JDC.model import JDCNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ee13ffd",
   "metadata": {},
   "outputs": [],
   "source": [
    "h = None\n",
    "device = None\n",
    "\n",
    "\n",
    "def load_checkpoint(filepath, device):\n",
    "    assert os.path.isfile(filepath)\n",
    "    print(\"Loading '{}'\".format(filepath))\n",
    "    checkpoint_dict = torch.load(filepath, map_location=device)\n",
    "    print(\"Complete.\")\n",
    "    return checkpoint_dict\n",
    "\n",
    "\n",
    "def get_mel(x):\n",
    "    return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)\n",
    "\n",
    "\n",
    "def scan_checkpoint(cp_dir, prefix):\n",
    "    pattern = os.path.join(cp_dir, prefix + '*')\n",
    "    cp_list = glob.glob(pattern)\n",
    "    if len(cp_list) == 0:\n",
    "        return ''\n",
    "    return sorted(cp_list)[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "003b1249",
   "metadata": {},
   "outputs": [],
   "source": [
    "F0_model = JDCNet(num_class=1, seq_len=192)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "321eb3b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "cp_path = \"/home/ubuntu/RINGFORMER/cp_ringformer_44.1khz\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3dcc2764",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'F0_path': '/home/ubuntu/Darya_Speech/Utils/JDC/epoch_00079.pth',\n",
       " 'resblock': '1',\n",
       " 'num_gpus': 1,\n",
       " 'batch_size': 18,\n",
       " 'learning_rate': 0.0002,\n",
       " 'adam_b1': 0.8,\n",
       " 'adam_b2': 0.99,\n",
       " 'lr_decay': 0.999,\n",
       " 'seed': 1234,\n",
       " 'upsample_rates': [16, 8],\n",
       " 'upsample_kernel_sizes': [32, 16],\n",
       " 'upsample_initial_channel': 512,\n",
       " 'resblock_kernel_sizes': [3, 7, 11],\n",
       " 'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
       " 'gen_istft_n_fft': 32,\n",
       " 'gen_istft_hop_size': 4,\n",
       " 'segment_size': 65536,\n",
       " 'num_mels': 128,\n",
       " 'n_fft': 2048,\n",
       " 'hop_size': 512,\n",
       " 'win_size': 2048,\n",
       " 'sampling_rate': 44100,\n",
       " 'fmin': 0,\n",
       " 'fmax': None,\n",
       " 'fmax_for_loss': None,\n",
       " 'num_workers': 8,\n",
       " 'dist_config': {'dist_backend': 'nccl',\n",
       "  'dist_url': 'tcp://localhost:54321',\n",
       "  'world_size': 1}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open(cp_path + \"/config.json\") as f:\n",
    "    data = f.read()\n",
    "\n",
    "json_config = json.loads(data)\n",
    "h = AttrDict(json_config)\n",
    "h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a4c78cb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# device = torch.device('cuda:{:d}'.format(0))\n",
    "\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5a782adb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/respair/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
      "  WeightNorm.apply(module, name, dim)\n"
     ]
    }
   ],
   "source": [
    "generator = Generator(h, F0_model).to(device)\n",
    "stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6f0a7c64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading '/home/ubuntu/RINGFORMER/cp_ringformer_44.1khz/g_00017000'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3972638/3295719764.py:8: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  checkpoint_dict = torch.load(filepath, map_location=device)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Complete.\n",
      "Removing weight norm...\n"
     ]
    }
   ],
   "source": [
    "cp_g = scan_checkpoint(cp_path, 'g_')\n",
    "state_dict_g = load_checkpoint(cp_g, device)\n",
    "generator.load_state_dict(state_dict_g['generator'])\n",
    "generator.remove_weight_norm()\n",
    "_ = generator.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a115a967",
   "metadata": {},
   "source": [
    "### Resynthesis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbeee500",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio\n",
    "from librosa.filters import mel as librosa_mel_fn\n",
    "from IPython.display import Audio\n",
    "import librosa\n",
    "\n",
    "to_mel = torchaudio.transforms.MelSpectrogram(\n",
    "    n_mels=128, n_fft=2048, win_length=2048, hop_length=512, sample_rate=44100, power=2.5)\n",
    "mean, std = -4, 4\n",
    "\n",
    "def preprocess(wave):\n",
    "    wave_tensor = torch.FloatTensor(wave)\n",
    "    mel_tensor = to_mel(wave_tensor)\n",
    "    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
    "    return mel_tensor.to('cuda:0')\n",
    "\n",
    "\n",
    "wav = librosa.load(\"/your.wav\", sr=44100)[0]\n",
    "\n",
    "x = preprocess(wav)\n",
    "print(x.shape)\n",
    "\n",
    "n = 1\n",
    "xxx = torch.load(\"/home/ubuntu/RINGFORMER/gt.pt\").to('cuda:0')[n:n+1,:,:]\n",
    "with torch.no_grad():\n",
    "\n",
    "    spec, phase = generator(xxx)\n",
    "    y_g_hat = stft.inverse(spec, phase)\n",
    "    audio = y_g_hat.squeeze()\n",
    "    # audio = audio * MAX_WAV_VALUE\n",
    "    audio = audio.cpu().numpy()\n",
    "\n",
    "\n",
    "print('Synthesized:')\n",
    "display(Audio(audio, rate=44100))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "respair",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
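The notebook imports `write` and `MAX_WAV_VALUE` but stops at in-notebook playback; a minimal follow-on sketch for saving the synthesized waveform to disk (the output filename is hypothetical, and `audio` is the float NumPy array from the resynthesis cell):

```python
import numpy as np
from scipy.io.wavfile import write

# clip to [-1, 1] and scale by MAX_WAV_VALUE - 1 so +1.0 cannot overflow int16
pcm = (np.clip(audio, -1.0, 1.0) * (MAX_WAV_VALUE - 1)).astype(np.int16)
write("resynthesized.wav", 44100, pcm)  # hypothetical output path
```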
RingFormer/meldataset.py
ADDED
@@ -0,0 +1,203 @@
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
import torchaudio
import librosa
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0
import soundfile as sf


def normalize_audio(wav):
    return wav / torch.max(torch.abs(torch.from_numpy(wav)))  # correct peak normalization

def load_wav(full_path):
    data, sampling_rate = librosa.load(full_path, sr=None)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):

    y = torch.clamp(y, min=-1, max=1)

    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


to_mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=44_100, n_mels=128, n_fft=2048, win_length=2048, hop_length=512)


# to_mel = torchaudio.transforms.MelSpectrogram(
#     sample_rate=24000, n_mels=80, n_fft=2048, win_length=1200, hop_length=300, center='center')

mean, std = -4, 4

def preproces(wave, to_mel=to_mel, device='cpu'):

    to_mel = to_mel.to(device)
    # wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave)
    mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std
    return mel_tensor


def get_dataset_filelist(a):
    with open(a.input_training_file, 'r', encoding='utf-8') as fi:
        training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('' if '' not in x else ''))
                          for x in fi.read().split('\n') if len(x) > 0]

    with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
        validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('' if '' not in x else ''))
                            for x in fi.read().split('\n') if len(x) > 0]
    return training_files, validation_files


class MelDataset(torch.utils.data.Dataset):
    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def __getitem__(self, index):
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            audio, sampling_rate = load_wav(filename)

            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate))
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)
        audio = audio.unsqueeze(0)

        if not self.fine_tuning:
            if self.split:
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start:audio_start+self.segment_size]
                else:
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

            # mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
            #                       self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
            #                       center=False)

            mel = preproces(audio)
        else:
            mel = np.load(
                os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start:mel_start + frames_per_seg]
                    audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
                else:
                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)

        # mel_loss = mel_spectrogram(audio)
        if mel.shape[-1] != mel_loss.shape[-1]:
            mel = mel[..., :mel_loss.shape[-1]]

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)
RingFormer/models.py
ADDED
@@ -0,0 +1,943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
6 |
+
from utils import init_weights, get_padding
|
7 |
+
import numpy as np
|
8 |
+
from stft import TorchSTFT
|
9 |
+
import torchaudio
|
10 |
+
from nnAudio import features
|
11 |
+
from einops import rearrange
|
12 |
+
from norm2d import NormConv2d
|
13 |
+
from utils import get_padding
|
14 |
+
from munch import Munch
|
15 |
+
from conformer import Conformer
|
16 |
+
|
17 |
+
LRELU_SLOPE = 0.1
|
18 |
+
|
19 |
+
|
20 |
+
def get_2d_padding(kernel_size, dilation=(1, 1)):
|
21 |
+
return (
|
22 |
+
((kernel_size[0] - 1) * dilation[0]) // 2,
|
23 |
+
((kernel_size[1] - 1) * dilation[1]) // 2,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
class ResBlock1(torch.nn.Module):
|
29 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
30 |
+
super(ResBlock1, self).__init__()
|
31 |
+
self.h = h
|
32 |
+
self.convs1 = nn.ModuleList([
|
33 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
34 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
35 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
36 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
37 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
38 |
+
padding=get_padding(kernel_size, dilation[2])))
|
39 |
+
])
|
40 |
+
self.convs1.apply(init_weights)
|
41 |
+
|
42 |
+
self.convs2 = nn.ModuleList([
|
43 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
44 |
+
padding=get_padding(kernel_size, 1))),
|
45 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
46 |
+
padding=get_padding(kernel_size, 1))),
|
47 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
48 |
+
padding=get_padding(kernel_size, 1)))
|
49 |
+
])
|
50 |
+
self.convs2.apply(init_weights)
|
51 |
+
|
52 |
+
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
|
53 |
+
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
|
54 |
+
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.alpha1, self.alpha2):
|
58 |
+
xt = x + (1 / a1) * (torch.sin(a1 * x) ** 2) # Snake1D
|
59 |
+
xt = c1(xt)
|
60 |
+
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
|
61 |
+
xt = c2(xt)
|
62 |
+
x = xt + x
|
63 |
+
return x
|
64 |
+
|
65 |
+
def remove_weight_norm(self):
|
66 |
+
for l in self.convs1:
|
67 |
+
remove_weight_norm(l)
|
68 |
+
for l in self.convs2:
|
69 |
+
remove_weight_norm(l)
|
70 |
+
|
71 |
+
class ResBlock1_old(torch.nn.Module):
|
72 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
73 |
+
super(ResBlock1, self).__init__()
|
74 |
+
self.h = h
|
75 |
+
self.convs1 = nn.ModuleList([
|
76 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
77 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
78 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
79 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
80 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
81 |
+
padding=get_padding(kernel_size, dilation[2])))
|
82 |
+
])
|
83 |
+
self.convs1.apply(init_weights)
|
84 |
+
|
85 |
+
self.convs2 = nn.ModuleList([
|
86 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
87 |
+
padding=get_padding(kernel_size, 1))),
|
88 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
89 |
+
padding=get_padding(kernel_size, 1))),
|
90 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
91 |
+
padding=get_padding(kernel_size, 1)))
|
92 |
+
])
|
93 |
+
self.convs2.apply(init_weights)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
97 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
98 |
+
xt = c1(xt)
|
99 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
100 |
+
xt = c2(xt)
|
101 |
+
x = xt + x
|
102 |
+
return x
|
103 |
+
|
104 |
+
def remove_weight_norm(self):
|
105 |
+
for l in self.convs1:
|
106 |
+
remove_weight_norm(l)
|
107 |
+
for l in self.convs2:
|
108 |
+
remove_weight_norm(l)
|
109 |
+
|
110 |
+
|
111 |
+
class ResBlock2(torch.nn.Module):
|
112 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
113 |
+
super(ResBlock2, self).__init__()
|
114 |
+
self.h = h
|
115 |
+
self.convs = nn.ModuleList([
|
116 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
117 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
118 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
119 |
+
padding=get_padding(kernel_size, dilation[1])))
|
120 |
+
])
|
121 |
+
self.convs.apply(init_weights)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
for c in self.convs:
|
125 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
126 |
+
xt = c(xt)
|
127 |
+
x = xt + x
|
128 |
+
return x
|
129 |
+
|
130 |
+
def remove_weight_norm(self):
|
131 |
+
for l in self.convs:
|
132 |
+
remove_weight_norm(l)
|
133 |
+
|
134 |
+
|
135 |
+
class SineGen(torch.nn.Module):
|
136 |
+
""" Definition of sine generator
|
137 |
+
SineGen(samp_rate, harmonic_num = 0,
|
138 |
+
sine_amp = 0.1, noise_std = 0.003,
|
139 |
+
voiced_threshold = 0,
|
140 |
+
flag_for_pulse=False)
|
141 |
+
samp_rate: sampling rate in Hz
|
142 |
+
harmonic_num: number of harmonic overtones (default 0)
|
143 |
+
sine_amp: amplitude of sine-waveform (default 0.1)
|
144 |
+
noise_std: std of Gaussian noise (default 0.003)
|
145 |
+
voiced_threshold: F0 threshold for U/V classification (default 0)
|
146 |
+
flag_for_pulse: this SineGen is used inside PulseGen (default False)
|
147 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
148 |
+
segment is always sin(np.pi) or cos(0)
|
149 |
+
"""
|
150 |
+
|
151 |
+
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
152 |
+
sine_amp=0.1, noise_std=0.003,
|
153 |
+
voiced_threshold=0,
|
154 |
+
flag_for_pulse=False):
|
155 |
+
super(SineGen, self).__init__()
|
156 |
+
self.sine_amp = sine_amp
|
157 |
+
self.noise_std = noise_std
|
158 |
+
self.harmonic_num = harmonic_num
|
159 |
+
self.dim = self.harmonic_num + 1
|
160 |
+
self.sampling_rate = samp_rate
|
161 |
+
self.voiced_threshold = voiced_threshold
|
162 |
+
self.flag_for_pulse = flag_for_pulse
|
163 |
+
self.upsample_scale = upsample_scale
|
164 |
+
|
165 |
+
def _f02uv(self, f0):
|
166 |
+
# generate uv signal
|
167 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
168 |
+
return uv
|
169 |
+
|
170 |
+
def _f02sine(self, f0_values):
|
171 |
+
""" f0_values: (batchsize, length, dim)
|
172 |
+
where dim indicates fundamental tone and overtones
|
173 |
+
"""
|
174 |
+
# convert to F0 in rad. The integer part n can be ignored
|
175 |
+
# because 2 * np.pi * n doesn't affect phase
|
176 |
+
rad_values = (f0_values / self.sampling_rate) % 1
|
177 |
+
|
178 |
+
# initial phase noise (no noise for fundamental component)
|
179 |
+
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
|
180 |
+
device=f0_values.device)
|
181 |
+
rand_ini[:, 0] = 0
|
182 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
183 |
+
|
184 |
+
# instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
|
185 |
+
if not self.flag_for_pulse:
|
186 |
+
# # for normal case
|
187 |
+
|
188 |
+
# # To prevent torch.cumsum numerical overflow,
|
189 |
+
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
190 |
+
# # Buffer tmp_over_one_idx indicates the time step to add -1.
|
191 |
+
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
192 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
193 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
194 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
195 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
196 |
+
|
197 |
+
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
198 |
+
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
199 |
+
scale_factor=1/self.upsample_scale,
|
200 |
+
mode="linear").transpose(1, 2)
|
201 |
+
|
202 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
203 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
204 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
205 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
206 |
+
|
207 |
+
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
208 |
+
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
209 |
+
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
210 |
+
sines = torch.sin(phase)
|
211 |
+
|
212 |
+
else:
|
213 |
+
# If necessary, make sure that the first time step of every
|
214 |
+
# voiced segment is sin(pi) or cos(0)
|
215 |
+
# This is used for pulse-train generation
|
216 |
+
|
217 |
+
# identify the last time step in unvoiced segments
|
218 |
+
uv = self._f02uv(f0_values)
|
219 |
+
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
220 |
+
uv_1[:, -1, :] = 1
|
221 |
+
u_loc = (uv < 1) * (uv_1 > 0)
|
222 |
+
|
223 |
+
# get the instantaneous phase
|
224 |
+
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
225 |
+
# different batch needs to be processed differently
|
226 |
+
for idx in range(f0_values.shape[0]):
|
227 |
+
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
228 |
+
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
229 |
+
# stores the accumulation of i.phase within
|
230 |
+
# each voiced segment
|
231 |
+
tmp_cumsum[idx, :, :] = 0
|
232 |
+
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
233 |
+
|
234 |
+
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
235 |
+
# within the previous voiced segment.
|
236 |
+
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
237 |
+
|
238 |
+
# get the sines
|
239 |
+
sines = torch.cos(i_phase * 2 * np.pi)
|
240 |
+
return sines
|
241 |
+
|
242 |
+
def forward(self, f0):
|
243 |
+
""" sine_tensor, uv = forward(f0)
|
244 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
245 |
+
f0 for unvoiced steps should be 0
|
246 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
247 |
+
output uv: tensor(batchsize=1, length, 1)
|
248 |
+
"""
|
249 |
+
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
|
250 |
+
device=f0.device)
|
251 |
+
# fundamental component and harmonic overtones: f0 * [1, 2, ..., harmonic_num + 1]
|
252 |
+
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
253 |
+
|
254 |
+
# generate sine waveforms
|
255 |
+
sine_waves = self._f02sine(fn) * self.sine_amp
|
256 |
+
|
257 |
+
# generate uv signal
|
258 |
+
# uv = torch.ones(f0.shape)
|
259 |
+
# uv = uv * (f0 > self.voiced_threshold)
|
260 |
+
uv = self._f02uv(f0)
|
261 |
+
|
262 |
+
# noise: for unvoiced should be similar to sine_amp
|
263 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
264 |
+
# for voiced regions the std is self.noise_std
|
265 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
266 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
267 |
+
|
268 |
+
# first: set the unvoiced part to 0 by uv
|
269 |
+
# then: additive noise
|
270 |
+
sine_waves = sine_waves * uv + noise
|
271 |
+
return sine_waves, uv, noise
|
272 |
+
|
273 |
+
|
274 |
+
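Going by the docstring and the shapes used above, a minimal usage sketch of SineGen. The sampling rate, upsample factor, and F0 values here are illustrative assumptions, not values from this repo's config:

import torch

# F0 is expected already upsampled to sample rate; the internal
# down/up interpolation uses upsample_scale (prod(upsample_rates) * istft hop)
sine_gen = SineGen(samp_rate=24000, upsample_scale=300, harmonic_num=8)

f0 = torch.full((1, 24000, 1), 220.0)  # (batch, samples, 1): one voiced second at 220 Hz
sine_waves, uv, noise = sine_gen(f0)   # sine_waves: (1, 24000, harmonic_num + 1)
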
class SourceModuleHnNSF(torch.nn.Module):
|
275 |
+
""" SourceModule for hn-nsf
|
276 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
277 |
+
add_noise_std=0.003, voiced_threshod=0)
|
278 |
+
sampling_rate: sampling_rate in Hz
|
279 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
280 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
281 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
282 |
+
note that amplitude of noise in unvoiced is decided
|
283 |
+
by sine_amp
|
284 |
+
voiced_threshold: threshold to set U/V given F0 (default: 0)
|
285 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
286 |
+
F0_sampled (batchsize, length, 1)
|
287 |
+
Sine_source (batchsize, length, 1)
|
288 |
+
noise_source (batchsize, length, 1)
|
289 |
+
uv (batchsize, length, 1)
|
290 |
+
"""
|
291 |
+
|
292 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
293 |
+
add_noise_std=0.003, voiced_threshod=0):
|
294 |
+
super(SourceModuleHnNSF, self).__init__()
|
295 |
+
|
296 |
+
self.sine_amp = sine_amp
|
297 |
+
self.noise_std = add_noise_std
|
298 |
+
|
299 |
+
# to produce sine waveforms
|
300 |
+
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
|
301 |
+
sine_amp, add_noise_std, voiced_threshod)
|
302 |
+
|
303 |
+
# to merge source harmonics into a single excitation
|
304 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
305 |
+
self.l_tanh = torch.nn.Tanh()
|
306 |
+
|
307 |
+
def forward(self, x):
|
308 |
+
"""
|
309 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
310 |
+
F0_sampled (batchsize, length, 1)
|
311 |
+
Sine_source (batchsize, length, 1)
|
312 |
+
noise_source (batchsize, length, 1)
|
313 |
+
"""
|
314 |
+
# source for harmonic branch
|
315 |
+
with torch.no_grad():
|
316 |
+
sine_wavs, uv, _ = self.l_sin_gen(x)
|
317 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
318 |
+
|
319 |
+
# source for noise branch, in the same shape as uv
|
320 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
321 |
+
return sine_merge, noise, uv
|
322 |
+
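A matching sketch for the source module as a whole, which merges the harmonics into a single excitation channel (same illustrative settings as in the SineGen sketch above):

import torch

source = SourceModuleHnNSF(sampling_rate=24000, upsample_scale=300, harmonic_num=8)
f0 = torch.full((1, 24000, 1), 120.0)  # F0 already at sample rate
sine_merge, noise, uv = source(f0)     # sine_merge: (1, 24000, 1) excitation
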
def padDiff(x):
|
323 |
+
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
|
324 |
+
|
325 |
+
|
326 |
+
|
327 |
+
class Generator(torch.nn.Module):
|
328 |
+
def __init__(self, h, F0_model):
|
329 |
+
super(Generator, self).__init__()
|
330 |
+
self.h = h
|
331 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
332 |
+
self.num_upsamples = len(h.upsample_rates)
|
333 |
+
self.conv_pre = weight_norm(Conv1d(128, h.upsample_initial_channel, 7, 1, padding=3))
|
334 |
+
|
335 |
+
|
336 |
+
|
337 |
+
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
338 |
+
|
339 |
+
self.m_source = SourceModuleHnNSF(
|
340 |
+
sampling_rate=h.sampling_rate,
|
341 |
+
upsample_scale=np.prod(h.upsample_rates) * h.gen_istft_hop_size,
|
342 |
+
harmonic_num=8, voiced_threshod=10)
|
343 |
+
|
344 |
+
self.f0_upsamp = torch.nn.Upsample(
|
345 |
+
scale_factor=np.prod(h.upsample_rates) * h.gen_istft_hop_size)
|
346 |
+
self.noise_convs = nn.ModuleList()
|
347 |
+
self.noise_res = nn.ModuleList()
|
348 |
+
|
349 |
+
self.F0_model = F0_model
|
350 |
+
|
351 |
+
self.ups = nn.ModuleList()
|
352 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
353 |
+
self.ups.append(weight_norm(
|
354 |
+
ConvTranspose1d(h.upsample_initial_channel//(2**i),
|
355 |
+
h.upsample_initial_channel//(2**(i+1)),
|
356 |
+
k,
|
357 |
+
u,
|
358 |
+
padding=(k-u)//2)))
|
359 |
+
|
360 |
+
c_cur = h.upsample_initial_channel // (2 ** (i + 1))
|
361 |
+
|
362 |
+
if i + 1 < len(h.upsample_rates):
|
363 |
+
stride_f0 = np.prod(h.upsample_rates[i + 1:])
|
364 |
+
self.noise_convs.append(Conv1d(
|
365 |
+
h.gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
|
366 |
+
self.noise_res.append(resblock(h, c_cur, 7, [1,3,5]))
|
367 |
+
else:
|
368 |
+
self.noise_convs.append(Conv1d(h.gen_istft_n_fft + 2, c_cur, kernel_size=1))
|
369 |
+
self.noise_res.append(resblock(h, c_cur, 11, [1,3,5]))
|
370 |
+
|
371 |
+
self.alphas = nn.ParameterList()
|
372 |
+
self.alphas.append(nn.Parameter(torch.ones(1, h.upsample_initial_channel, 1)))
|
373 |
+
self.resblocks = nn.ModuleList()
|
374 |
+
for i in range(len(self.ups)):
|
375 |
+
ch = h.upsample_initial_channel//(2**(i+1))
|
376 |
+
self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
|
377 |
+
for j, (k, d) in enumerate(
|
378 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
379 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
380 |
+
|
381 |
+
|
382 |
+
self.conformers = nn.ModuleList()
|
383 |
+
self.post_n_fft = h.gen_istft_n_fft
|
384 |
+
self.conv_post = weight_norm(Conv1d(128, self.post_n_fft + 2, 7, 1, padding=3))
|
385 |
+
|
386 |
+
for i in range(len(self.ups)):
|
387 |
+
ch = h.upsample_initial_channel // (2**i)
|
388 |
+
self.conformers.append(
|
389 |
+
Conformer(
|
390 |
+
dim=ch,
|
391 |
+
depth=2,
|
392 |
+
dim_head=64,
|
393 |
+
heads=8,
|
394 |
+
ff_mult=4,
|
395 |
+
conv_expansion_factor=2,
|
396 |
+
conv_kernel_size=31,
|
397 |
+
attn_dropout=0.1,
|
398 |
+
ff_dropout=0.1,
|
399 |
+
conv_dropout=0.1,
|
400 |
+
# device=self.device
|
401 |
+
)
|
402 |
+
)
|
403 |
+
|
404 |
+
self.ups.apply(init_weights)
|
405 |
+
self.conv_post.apply(init_weights)
|
406 |
+
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
407 |
+
self.stft = TorchSTFT(filter_length=h.gen_istft_n_fft,
|
408 |
+
hop_length=h.gen_istft_hop_size,
|
409 |
+
win_length=h.gen_istft_n_fft)
|
410 |
+
|
411 |
+
|
412 |
+
|
413 |
+
def forward(self, x):
|
414 |
+
|
415 |
+
|
416 |
+
|
417 |
+
f0, _, _ = self.F0_model(x.unsqueeze(1))
|
418 |
+
if len(f0.shape) == 1:
|
419 |
+
f0 = f0.unsqueeze(0)
|
420 |
+
|
421 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # (batch, samples, 1)
|
422 |
+
|
423 |
+
har_source, _, _ = self.m_source(f0)
|
424 |
+
har_source = har_source.transpose(1, 2).squeeze(1)
|
425 |
+
har_spec, har_phase = self.stft.transform(har_source)
|
426 |
+
har = torch.cat([har_spec, har_phase], dim=1)
|
427 |
+
|
428 |
+
|
429 |
+
x = self.conv_pre(x)
|
430 |
+
|
431 |
+
for i in range(self.num_upsamples):
|
432 |
+
|
433 |
+
x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
|
434 |
+
x = rearrange(x, "b f t -> b t f")
|
435 |
+
|
436 |
+
x = self.conformers[i](x)
|
437 |
+
|
438 |
+
x = rearrange(x, "b t f -> b f t")
|
439 |
+
|
440 |
+
# x = F.leaky_relu(x, LRELU_SLOPE)
|
441 |
+
x_source = self.noise_convs[i](har)
|
442 |
+
x_source = self.noise_res[i](x_source)
|
443 |
+
|
444 |
+
x = self.ups[i](x)
|
445 |
+
if i == self.num_upsamples - 1:
|
446 |
+
x = self.reflection_pad(x)
|
447 |
+
|
448 |
+
x = x + x_source
|
449 |
+
|
450 |
+
|
451 |
+
xs = None
|
452 |
+
for j in range(self.num_kernels):
|
453 |
+
if xs is None:
|
454 |
+
xs = self.resblocks[i*self.num_kernels+j](x)
|
455 |
+
else:
|
456 |
+
xs += self.resblocks[i*self.num_kernels+j](x)
|
457 |
+
x = xs / self.num_kernels
|
458 |
+
# x = F.leaky_relu(x)
|
459 |
+
|
460 |
+
|
461 |
+
x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)
|
462 |
+
|
463 |
+
x = self.conv_post(x)
|
464 |
+
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]).to(x.device)
|
465 |
+
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]).to(x.device)
|
466 |
+
|
467 |
+
return spec, phase
|
468 |
+
|
469 |
+
def remove_weight_norm(self):
|
470 |
+
print("Removing weight norm...")
|
471 |
+
for l in self.ups:
|
472 |
+
remove_weight_norm(l)
|
473 |
+
for l in self.resblocks:
|
474 |
+
l.remove_weight_norm()
|
475 |
+
remove_weight_norm(self.conv_pre)
|
476 |
+
remove_weight_norm(self.conv_post)
|
477 |
+
|
478 |
+
|
479 |
+
|
480 |
+
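Note that the Generator returns a (spec, phase) pair rather than a waveform; as the commented-out training loop later in this diff does, the audio is reconstructed with TorchSTFT.inverse. A minimal inference sketch, assuming a loaded config `h`, a pretrained `F0_model`, and a mel batch `mel` of shape (batch, 128, frames):

import torch

stft_mod = TorchSTFT(filter_length=h.gen_istft_n_fft,
                     hop_length=h.gen_istft_hop_size,
                     win_length=h.gen_istft_n_fft)
generator = Generator(h, F0_model).eval()

with torch.no_grad():
    spec, phase = generator(mel)           # magnitude and phase frames
    audio = stft_mod.inverse(spec, phase)  # (batch, 1, samples)
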
def stft(x, fft_size, hop_size, win_length, window):
|
481 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
482 |
+
Args:
|
483 |
+
x (Tensor): Input signal tensor (B, T).
|
484 |
+
fft_size (int): FFT size.
|
485 |
+
hop_size (int): Hop size.
|
486 |
+
win_length (int): Window length.
|
487 |
+
window (str): Window function type.
|
488 |
+
Returns:
|
489 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
490 |
+
"""
|
491 |
+
x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
|
492 |
+
return_complex=True)
|
493 |
+
# return_complex=True yields a complex tensor; torch.abs below gives the magnitude directly
|
495 |
+
|
496 |
+
# NOTE: unlike the stft helper in stft.py, this variant takes the magnitude via torch.abs on the complex output
|
497 |
+
return torch.abs(x_stft).transpose(2, 1)
|
498 |
+
|
499 |
+
class SpecDiscriminator(nn.Module):
|
500 |
+
"""docstring for Discriminator."""
|
501 |
+
|
502 |
+
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
|
503 |
+
super(SpecDiscriminator, self).__init__()
|
504 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
505 |
+
self.fft_size = fft_size
|
506 |
+
self.shift_size = shift_size
|
507 |
+
self.win_length = win_length
|
508 |
+
self.window = getattr(torch, window)(win_length)
|
509 |
+
self.discriminators = nn.ModuleList([
|
510 |
+
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
511 |
+
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
|
512 |
+
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
|
513 |
+
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
|
514 |
+
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
|
515 |
+
])
|
516 |
+
|
517 |
+
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
518 |
+
|
519 |
+
def forward(self, y):
|
520 |
+
|
521 |
+
fmap = []
|
522 |
+
y = y.squeeze(1)
|
523 |
+
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
|
524 |
+
y = y.unsqueeze(1)
|
525 |
+
for i, d in enumerate(self.discriminators):
|
526 |
+
y = d(y)
|
527 |
+
y = F.leaky_relu(y, LRELU_SLOPE)
|
528 |
+
fmap.append(y)
|
529 |
+
|
530 |
+
y = self.out(y)
|
531 |
+
fmap.append(y)
|
532 |
+
|
533 |
+
return torch.flatten(y, 1, -1), fmap
|
534 |
+
|
535 |
+
# class MultiResSpecDiscriminator(torch.nn.Module):
|
536 |
+
|
537 |
+
# def __init__(self,
|
538 |
+
# fft_sizes=[1024, 2048, 512],
|
539 |
+
# hop_sizes=[120, 240, 50],
|
540 |
+
# win_lengths=[600, 1200, 240],
|
541 |
+
# window="hann_window"):
|
542 |
+
|
543 |
+
# super(MultiResSpecDiscriminator, self).__init__()
|
544 |
+
# self.discriminators = nn.ModuleList([
|
545 |
+
# SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
|
546 |
+
# SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
|
547 |
+
# SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
|
548 |
+
# ])
|
549 |
+
|
550 |
+
# def forward(self, y, y_hat):
|
551 |
+
# y_d_rs = []
|
552 |
+
# y_d_gs = []
|
553 |
+
# fmap_rs = []
|
554 |
+
# fmap_gs = []
|
555 |
+
# for i, d in enumerate(self.discriminators):
|
556 |
+
# y_d_r, fmap_r = d(y)
|
557 |
+
# y_d_g, fmap_g = d(y_hat)
|
558 |
+
# y_d_rs.append(y_d_r)
|
559 |
+
# fmap_rs.append(fmap_r)
|
560 |
+
# y_d_gs.append(y_d_g)
|
561 |
+
# fmap_gs.append(fmap_g)
|
562 |
+
|
563 |
+
# return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
564 |
+
|
565 |
+
|
566 |
+
class DiscriminatorP(torch.nn.Module):
|
567 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
568 |
+
super(DiscriminatorP, self).__init__()
|
569 |
+
self.period = period
|
570 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
571 |
+
self.convs = nn.ModuleList([
|
572 |
+
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
573 |
+
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
574 |
+
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
575 |
+
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
576 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
577 |
+
])
|
578 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
579 |
+
|
580 |
+
def forward(self, x):
|
581 |
+
fmap = []
|
582 |
+
|
583 |
+
# 1d to 2d
|
584 |
+
b, c, t = x.shape
|
585 |
+
if t % self.period != 0: # pad first
|
586 |
+
n_pad = self.period - (t % self.period)
|
587 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
588 |
+
t = t + n_pad
|
589 |
+
x = x.view(b, c, t // self.period, self.period)
|
590 |
+
|
591 |
+
for l in self.convs:
|
592 |
+
x = l(x)
|
593 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
594 |
+
fmap.append(x)
|
595 |
+
x = self.conv_post(x)
|
596 |
+
fmap.append(x)
|
597 |
+
x = torch.flatten(x, 1, -1)
|
598 |
+
|
599 |
+
return x, fmap
|
600 |
+
|
601 |
+
|
602 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
603 |
+
def __init__(self):
|
604 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
605 |
+
self.discriminators = nn.ModuleList([
|
606 |
+
DiscriminatorP(2),
|
607 |
+
DiscriminatorP(3),
|
608 |
+
DiscriminatorP(5),
|
609 |
+
DiscriminatorP(7),
|
610 |
+
DiscriminatorP(11),
|
611 |
+
])
|
612 |
+
|
613 |
+
def forward(self, y, y_hat):
|
614 |
+
y_d_rs = []
|
615 |
+
y_d_gs = []
|
616 |
+
fmap_rs = []
|
617 |
+
fmap_gs = []
|
618 |
+
for i, d in enumerate(self.discriminators):
|
619 |
+
y_d_r, fmap_r = d(y)
|
620 |
+
y_d_g, fmap_g = d(y_hat)
|
621 |
+
y_d_rs.append(y_d_r)
|
622 |
+
fmap_rs.append(fmap_r)
|
623 |
+
y_d_gs.append(y_d_g)
|
624 |
+
fmap_gs.append(fmap_g)
|
625 |
+
|
626 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
627 |
+
|
628 |
+
|
629 |
+
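A quick shape check for the period discriminators (batch size and waveform length here are arbitrary):

import torch

mpd = MultiPeriodDiscriminator()
y = torch.randn(2, 1, 8192)      # real waveforms
y_hat = torch.randn(2, 1, 8192)  # generated waveforms
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
assert len(y_d_rs) == 5          # one score tensor per period (2, 3, 5, 7, 11)
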
class DiscriminatorS(torch.nn.Module):
|
630 |
+
def __init__(self, use_spectral_norm=False):
|
631 |
+
super(DiscriminatorS, self).__init__()
|
632 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
633 |
+
self.convs = nn.ModuleList([
|
634 |
+
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
635 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
636 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
637 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
638 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
639 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
640 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
641 |
+
])
|
642 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
643 |
+
|
644 |
+
def forward(self, x):
|
645 |
+
fmap = []
|
646 |
+
for l in self.convs:
|
647 |
+
x = l(x)
|
648 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
649 |
+
fmap.append(x)
|
650 |
+
x = self.conv_post(x)
|
651 |
+
fmap.append(x)
|
652 |
+
x = torch.flatten(x, 1, -1)
|
653 |
+
|
654 |
+
return x, fmap
|
655 |
+
|
656 |
+
|
657 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
658 |
+
def __init__(self):
|
659 |
+
super(MultiScaleDiscriminator, self).__init__()
|
660 |
+
self.discriminators = nn.ModuleList([
|
661 |
+
DiscriminatorS(use_spectral_norm=True),
|
662 |
+
DiscriminatorS(),
|
663 |
+
DiscriminatorS(),
|
664 |
+
])
|
665 |
+
self.meanpools = nn.ModuleList([
|
666 |
+
AvgPool1d(4, 2, padding=2),
|
667 |
+
AvgPool1d(4, 2, padding=2)
|
668 |
+
])
|
669 |
+
|
670 |
+
def forward(self, y, y_hat):
|
671 |
+
y_d_rs = []
|
672 |
+
y_d_gs = []
|
673 |
+
fmap_rs = []
|
674 |
+
fmap_gs = []
|
675 |
+
for i, d in enumerate(self.discriminators):
|
676 |
+
if i != 0:
|
677 |
+
y = self.meanpools[i-1](y)
|
678 |
+
y_hat = self.meanpools[i-1](y_hat)
|
679 |
+
y_d_r, fmap_r = d(y)
|
680 |
+
y_d_g, fmap_g = d(y_hat)
|
681 |
+
y_d_rs.append(y_d_r)
|
682 |
+
fmap_rs.append(fmap_r)
|
683 |
+
y_d_gs.append(y_d_g)
|
684 |
+
fmap_gs.append(fmap_g)
|
685 |
+
|
686 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
687 |
+
|
688 |
+
|
689 |
+
|
690 |
+
|
691 |
+
|
692 |
+
########################### from ringformer
|
693 |
+
|
694 |
+
multiscale_subband_cfg = {
|
695 |
+
"hop_lengths": [1024, 512, 512], # Doubled to maintain similar time resolution
|
696 |
+
"sampling_rate": 44100, # New sampling rate
|
697 |
+
"filters": 32, # Kept same as it controls initial feature dimension
|
698 |
+
"max_filters": 1024, # Kept same as it's a maximum limit
|
699 |
+
"filters_scale": 1, # Kept same as it's a scaling factor
|
700 |
+
"dilations": [1, 2, 4], # Kept same as they control receptive field growth
|
701 |
+
"in_channels": 1, # Kept same (mono audio)
|
702 |
+
"out_channels": 1, # Kept same (mono audio)
|
703 |
+
"n_octaves": [10, 10, 10], # Increased by 1 to handle higher frequency range
|
704 |
+
"bins_per_octaves": [24, 36, 48], # Kept same as they control frequency resolution
|
705 |
+
}
|
706 |
+
|
707 |
+
|
708 |
+
|
709 |
+
# multiscale_subband_cfg = {
|
710 |
+
# "hop_lengths": [512, 256, 256],
|
711 |
+
# "sampling_rate": 24000,
|
712 |
+
# "filters": 32,
|
713 |
+
# "max_filters": 1024,
|
714 |
+
# "filters_scale": 1,
|
715 |
+
# "dilations": [1, 2, 4],
|
716 |
+
# "in_channels": 1,
|
717 |
+
# "out_channels": 1,
|
718 |
+
# "n_octaves": [9, 9, 9],
|
719 |
+
# "bins_per_octaves": [24, 36, 48],
|
720 |
+
# }
|
721 |
+
|
722 |
+
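DiscriminatorCQT below reads these settings as attributes (cfg.filters, cfg.hop_lengths, ...), which is why MultiScaleSubbandCQTDiscriminator wraps the dict in a Munch. A quick sketch of that conversion:

from munch import Munch

cfg = Munch(multiscale_subband_cfg)
assert cfg.sampling_rate == 44100
assert len(cfg.hop_lengths) == len(cfg.n_octaves) == len(cfg.bins_per_octaves) == 3
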
class DiscriminatorCQT(nn.Module):
|
723 |
+
def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
|
724 |
+
super(DiscriminatorCQT, self).__init__()
|
725 |
+
self.cfg = cfg
|
726 |
+
|
727 |
+
self.filters = cfg.filters
|
728 |
+
self.max_filters = cfg.max_filters
|
729 |
+
self.filters_scale = cfg.filters_scale
|
730 |
+
self.kernel_size = (3, 9)
|
731 |
+
self.dilations = cfg.dilations
|
732 |
+
self.stride = (1, 2)
|
733 |
+
|
734 |
+
self.in_channels = cfg.in_channels
|
735 |
+
self.out_channels = cfg.out_channels
|
736 |
+
self.fs = cfg.sampling_rate
|
737 |
+
self.hop_length = hop_length
|
738 |
+
self.n_octaves = n_octaves
|
739 |
+
self.bins_per_octave = bins_per_octave
|
740 |
+
|
741 |
+
self.cqt_transform = features.cqt.CQT2010v2(
|
742 |
+
sr=self.fs * 2,
|
743 |
+
hop_length=self.hop_length,
|
744 |
+
n_bins=self.bins_per_octave * self.n_octaves,
|
745 |
+
bins_per_octave=self.bins_per_octave,
|
746 |
+
output_format="Complex",
|
747 |
+
pad_mode="constant",
|
748 |
+
)
|
749 |
+
|
750 |
+
self.conv_pres = nn.ModuleList()
|
751 |
+
for i in range(self.n_octaves):
|
752 |
+
self.conv_pres.append(
|
753 |
+
NormConv2d(
|
754 |
+
self.in_channels * 2,
|
755 |
+
self.in_channels * 2,
|
756 |
+
kernel_size=self.kernel_size,
|
757 |
+
padding=get_2d_padding(self.kernel_size),
|
758 |
+
)
|
759 |
+
)
|
760 |
+
|
761 |
+
self.convs = nn.ModuleList()
|
762 |
+
|
763 |
+
self.convs.append(
|
764 |
+
NormConv2d(
|
765 |
+
self.in_channels * 2,
|
766 |
+
self.filters,
|
767 |
+
kernel_size=self.kernel_size,
|
768 |
+
padding=get_2d_padding(self.kernel_size),
|
769 |
+
)
|
770 |
+
)
|
771 |
+
|
772 |
+
in_chs = min(self.filters_scale * self.filters, self.max_filters)
|
773 |
+
for i, dilation in enumerate(self.dilations):
|
774 |
+
out_chs = min(
|
775 |
+
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
|
776 |
+
)
|
777 |
+
self.convs.append(
|
778 |
+
NormConv2d(
|
779 |
+
in_chs,
|
780 |
+
out_chs,
|
781 |
+
kernel_size=self.kernel_size,
|
782 |
+
stride=self.stride,
|
783 |
+
dilation=(dilation, 1),
|
784 |
+
padding=get_2d_padding(self.kernel_size, (dilation, 1)),
|
785 |
+
norm="weight_norm",
|
786 |
+
)
|
787 |
+
)
|
788 |
+
in_chs = out_chs
|
789 |
+
out_chs = min(
|
790 |
+
(self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
|
791 |
+
self.max_filters,
|
792 |
+
)
|
793 |
+
self.convs.append(
|
794 |
+
NormConv2d(
|
795 |
+
in_chs,
|
796 |
+
out_chs,
|
797 |
+
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
798 |
+
padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
799 |
+
norm="weight_norm",
|
800 |
+
)
|
801 |
+
)
|
802 |
+
|
803 |
+
self.conv_post = NormConv2d(
|
804 |
+
out_chs,
|
805 |
+
self.out_channels,
|
806 |
+
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
807 |
+
padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
808 |
+
norm="weight_norm",
|
809 |
+
)
|
810 |
+
|
811 |
+
self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
|
812 |
+
self.resample = torchaudio.transforms.Resample(
|
813 |
+
orig_freq=self.fs, new_freq=self.fs * 2
|
814 |
+
)
|
815 |
+
|
816 |
+
def forward(self, x):
|
817 |
+
fmap = []
|
818 |
+
|
819 |
+
x = self.resample(x)
|
820 |
+
|
821 |
+
z = self.cqt_transform(x)
|
822 |
+
|
823 |
+
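# nnAudio's "Complex" output stacks real and imaginary parts in the last
# dimension; the two slices below are treated as amplitude/phase-like channels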
z_amplitude = z[:, :, :, 0].unsqueeze(1)
|
824 |
+
z_phase = z[:, :, :, 1].unsqueeze(1)
|
825 |
+
|
826 |
+
z = torch.cat([z_amplitude, z_phase], dim=1)
|
827 |
+
z = rearrange(z, "b c w t -> b c t w")
|
828 |
+
|
829 |
+
latent_z = []
|
830 |
+
for i in range(self.n_octaves):
|
831 |
+
latent_z.append(
|
832 |
+
self.conv_pres[i](
|
833 |
+
z[
|
834 |
+
:,
|
835 |
+
:,
|
836 |
+
:,
|
837 |
+
i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
|
838 |
+
]
|
839 |
+
)
|
840 |
+
)
|
841 |
+
latent_z = torch.cat(latent_z, dim=-1)
|
842 |
+
|
843 |
+
for i, l in enumerate(self.convs):
|
844 |
+
latent_z = l(latent_z)
|
845 |
+
|
846 |
+
latent_z = self.activation(latent_z)
|
847 |
+
fmap.append(latent_z)
|
848 |
+
|
849 |
+
latent_z = self.conv_post(latent_z)
|
850 |
+
|
851 |
+
return latent_z, fmap
|
852 |
+
|
853 |
+
|
854 |
+
|
855 |
+
class MultiScaleSubbandCQTDiscriminator(nn.Module): # replacing "MultiResSpecDiscriminator"
|
856 |
+
def __init__(self):
|
857 |
+
super(MultiScaleSubbandCQTDiscriminator, self).__init__()
|
858 |
+
cfg = Munch(multiscale_subband_cfg)
|
859 |
+
self.cfg = cfg
|
860 |
+
self.discriminators = nn.ModuleList(
|
861 |
+
[
|
862 |
+
DiscriminatorCQT(
|
863 |
+
cfg,
|
864 |
+
hop_length=cfg.hop_lengths[i],
|
865 |
+
n_octaves=cfg.n_octaves[i],
|
866 |
+
bins_per_octave=cfg.bins_per_octaves[i],
|
867 |
+
)
|
868 |
+
for i in range(len(cfg.hop_lengths))
|
869 |
+
]
|
870 |
+
)
|
871 |
+
|
872 |
+
def forward(self, y, y_hat):
|
873 |
+
y_d_rs = []
|
874 |
+
y_d_gs = []
|
875 |
+
fmap_rs = []
|
876 |
+
fmap_gs = []
|
877 |
+
|
878 |
+
for disc in self.discriminators:
|
879 |
+
y_d_r, fmap_r = disc(y)
|
880 |
+
y_d_g, fmap_g = disc(y_hat)
|
881 |
+
y_d_rs.append(y_d_r)
|
882 |
+
fmap_rs.append(fmap_r)
|
883 |
+
y_d_gs.append(y_d_g)
|
884 |
+
fmap_gs.append(fmap_g)
|
885 |
+
|
886 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
887 |
+
|
888 |
+
|
889 |
+
|
890 |
+
#############################
|
891 |
+
|
892 |
+
|
893 |
+
|
894 |
+
def feature_loss(fmap_r, fmap_g):
|
895 |
+
loss = 0
|
896 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
897 |
+
for rl, gl in zip(dr, dg):
|
898 |
+
loss += torch.mean(torch.abs(rl - gl))
|
899 |
+
|
900 |
+
return loss*2
|
901 |
+
|
902 |
+
|
903 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
904 |
+
loss = 0
|
905 |
+
r_losses = []
|
906 |
+
g_losses = []
|
907 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
908 |
+
r_loss = torch.mean((1-dr)**2)
|
909 |
+
g_loss = torch.mean(dg**2)
|
910 |
+
loss += (r_loss + g_loss)
|
911 |
+
r_losses.append(r_loss.item())
|
912 |
+
g_losses.append(g_loss.item())
|
913 |
+
|
914 |
+
return loss, r_losses, g_losses
|
915 |
+
|
916 |
+
|
917 |
+
def generator_loss(disc_outputs):
|
918 |
+
loss = 0
|
919 |
+
gen_losses = []
|
920 |
+
for dg in disc_outputs:
|
921 |
+
l = torch.mean((1-dg)**2)
|
922 |
+
gen_losses.append(l)
|
923 |
+
loss += l
|
924 |
+
|
925 |
+
return loss, gen_losses
|
926 |
+
|
927 |
+
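These are the standard LSGAN objectives (real scores pushed toward 1 and fake toward 0 for the discriminator, fake toward 1 for the generator). A tiny numeric sanity check with made-up scores:

import torch

dr = [torch.tensor([0.9, 1.1])]   # near-perfect real scores
dg = [torch.tensor([0.1, -0.1])]  # near-perfect fake scores
loss, r_losses, g_losses = discriminator_loss(dr, dg)
assert abs(loss.item() - 0.02) < 1e-6  # mean((1-dr)^2) + mean(dg^2)
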
def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
|
928 |
+
loss = 0
|
929 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
930 |
+
tau = 0.04
|
931 |
+
m_DG = torch.median((dr-dg))
|
932 |
+
L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
|
933 |
+
loss += tau - F.relu(tau - L_rel)
|
934 |
+
return loss
|
935 |
+
|
936 |
+
def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
|
937 |
+
loss = 0
|
938 |
+
for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
|
939 |
+
tau = 0.04
|
940 |
+
m_DG = torch.median((dr-dg))
|
941 |
+
L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
|
942 |
+
loss += tau - F.relu(tau - L_rel)
|
943 |
+
return loss
|
RingFormer/norm2d.py
ADDED
@@ -0,0 +1,92 @@
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
#################### Norm2D for Discriminators ####################
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import einops
|
11 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
12 |
+
|
13 |
+
CONV_NORMALIZATIONS = frozenset(
|
14 |
+
[
|
15 |
+
"none",
|
16 |
+
"weight_norm",
|
17 |
+
"spectral_norm",
|
18 |
+
"time_layer_norm",
|
19 |
+
"layer_norm",
|
20 |
+
"time_group_norm",
|
21 |
+
]
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class ConvLayerNorm(nn.LayerNorm):
|
26 |
+
"""
|
27 |
+
Convolution-friendly LayerNorm that moves channels to last dimensions
|
28 |
+
before running the normalization and moves them back to original position right after.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, normalized_shape, **kwargs):
|
32 |
+
super().__init__(normalized_shape, **kwargs)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = einops.rearrange(x, "b ... t -> b t ...")
|
36 |
+
x = super().forward(x)
|
37 |
+
x = einops.rearrange(x, "b t ... -> b ... t")
|
38 |
+
return x
|
39 |
+
|
40 |
+
|
41 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
|
42 |
+
assert norm in CONV_NORMALIZATIONS
|
43 |
+
if norm == "weight_norm":
|
44 |
+
return weight_norm(module)
|
45 |
+
elif norm == "spectral_norm":
|
46 |
+
return spectral_norm(module)
|
47 |
+
else:
|
48 |
+
# We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
|
49 |
+
# doesn't need reparametrization.
|
50 |
+
return module
|
51 |
+
|
52 |
+
|
53 |
+
def get_norm_module(
|
54 |
+
module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
|
55 |
+
) -> nn.Module:
|
56 |
+
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
57 |
+
module is causal, or raise an error if the normalization doesn't support causal evaluation.
|
58 |
+
"""
|
59 |
+
assert norm in CONV_NORMALIZATIONS
|
60 |
+
if norm == "layer_norm":
|
61 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
62 |
+
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
63 |
+
elif norm == "time_group_norm":
|
64 |
+
if causal:
|
65 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
66 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
67 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
68 |
+
else:
|
69 |
+
return nn.Identity()
|
70 |
+
|
71 |
+
|
72 |
+
class NormConv2d(nn.Module):
|
73 |
+
"""Wrapper around Conv2d and normalization applied to this conv
|
74 |
+
to provide a uniform interface across normalization approaches.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
*args,
|
80 |
+
norm: str = "none",
|
81 |
+
norm_kwargs={},
|
82 |
+
**kwargs,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
86 |
+
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
87 |
+
self.norm_type = norm
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
x = self.conv(x)
|
91 |
+
x = self.norm(x)
|
92 |
+
return x
|
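A short usage sketch for the wrapper (input shape is arbitrary):

import torch

# weight-normalized Conv2d followed by an identity norm module
conv = NormConv2d(1, 32, kernel_size=(3, 9), padding=(1, 4), norm="weight_norm")
y = conv(torch.randn(2, 1, 80, 100))  # -> (2, 32, 80, 100)
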
RingFormer/requirements.txt
ADDED
@@ -0,0 +1,10 @@
1 |
+
torch
|
2 |
+
numpy
|
3 |
+
librosa
|
4 |
+
scipy==1.4.1
|
5 |
+
tensorboard==2.0
|
6 |
+
soundfile
|
7 |
+
matplotlib==3.1.3
|
8 |
+
nnaudio
|
9 |
+
ring-attention-pytorch
|
10 |
+
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/triton-nightly
|
RingFormer/stft.py
ADDED
@@ -0,0 +1,254 @@
1 |
+
"""
|
2 |
+
BSD 3-Clause License
|
3 |
+
Copyright (c) 2017, Prem Seetharaman
|
4 |
+
All rights reserved.
|
5 |
+
* Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
* Redistributions of source code must retain the above copyright notice,
|
8 |
+
this list of conditions and the following disclaimer.
|
9 |
+
* Redistributions in binary form must reproduce the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer in the
|
11 |
+
documentation and/or other materials provided with the distribution.
|
12 |
+
* Neither the name of the copyright holder nor the names of its
|
13 |
+
contributors may be used to endorse or promote products derived from this
|
14 |
+
software without specific prior written permission.
|
15 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
16 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
17 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
19 |
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
22 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
+
"""
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import numpy as np
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torch.autograd import Variable
|
31 |
+
from scipy.signal import get_window
|
32 |
+
from librosa.util import pad_center, tiny
|
33 |
+
import librosa.util as librosa_util
|
34 |
+
|
35 |
+
|
36 |
+
def window_sumsquare(
|
37 |
+
window,
|
38 |
+
n_frames,
|
39 |
+
hop_length=200,
|
40 |
+
win_length=800,
|
41 |
+
n_fft=800,
|
42 |
+
dtype=np.float32,
|
43 |
+
norm=None,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
# from librosa 0.6
|
47 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
48 |
+
This is used to estimate modulation effects induced by windowing
|
49 |
+
observations in short-time Fourier transforms.
|
50 |
+
Parameters
|
51 |
+
----------
|
52 |
+
window : string, tuple, number, callable, or list-like
|
53 |
+
Window specification, as in `get_window`
|
54 |
+
n_frames : int > 0
|
55 |
+
The number of analysis frames
|
56 |
+
hop_length : int > 0
|
57 |
+
The number of samples to advance between frames
|
58 |
+
win_length : [optional]
|
59 |
+
The length of the window function. By default, this matches `n_fft`.
|
60 |
+
n_fft : int > 0
|
61 |
+
The length of each analysis frame.
|
62 |
+
dtype : np.dtype
|
63 |
+
The data type of the output
|
64 |
+
Returns
|
65 |
+
-------
|
66 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
67 |
+
The sum-squared envelope of the window function
|
68 |
+
"""
|
69 |
+
if win_length is None:
|
70 |
+
win_length = n_fft
|
71 |
+
|
72 |
+
n = n_fft + hop_length * (n_frames - 1)
|
73 |
+
x = np.zeros(n, dtype=dtype)
|
74 |
+
|
75 |
+
# Compute the squared window at the desired length
|
76 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
77 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
|
78 |
+
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
79 |
+
|
80 |
+
# Fill the envelope
|
81 |
+
for i in range(n_frames):
|
82 |
+
sample = i * hop_length
|
83 |
+
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
def stft(x, fft_size, hop_size, win_length, window):
|
88 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
89 |
+
Args:
|
90 |
+
x (Tensor): Input signal tensor (B, T).
|
91 |
+
fft_size (int): FFT size.
|
92 |
+
hop_size (int): Hop size.
|
93 |
+
win_length (int): Window length.
|
94 |
+
window (str): Window function type.
|
95 |
+
Returns:
|
96 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
97 |
+
"""
|
98 |
+
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
|
99 |
+
real = x_stft[..., 0]
|
100 |
+
imag = x_stft[..., 1]
|
101 |
+
|
102 |
+
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
103 |
+
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
|
104 |
+
|
105 |
+
|
106 |
+
class STFT(torch.nn.Module):
|
107 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self, filter_length=800, hop_length=200, win_length=800, window="hann"
|
111 |
+
):
|
112 |
+
super(STFT, self).__init__()
|
113 |
+
self.filter_length = filter_length
|
114 |
+
self.hop_length = hop_length
|
115 |
+
self.win_length = win_length
|
116 |
+
self.window = window
|
117 |
+
self.forward_transform = None
|
118 |
+
scale = self.filter_length / self.hop_length
|
119 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
120 |
+
|
121 |
+
cutoff = int((self.filter_length / 2 + 1))
|
122 |
+
fourier_basis = np.vstack(
|
123 |
+
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
124 |
+
)
|
125 |
+
|
126 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
127 |
+
inverse_basis = torch.FloatTensor(
|
128 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
129 |
+
)
|
130 |
+
|
131 |
+
if window is not None:
|
132 |
+
assert filter_length >= win_length
|
133 |
+
# get window and zero center pad it to filter_length
|
134 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
135 |
+
fft_window = pad_center(fft_window, filter_length)
|
136 |
+
fft_window = torch.from_numpy(fft_window).float()
|
137 |
+
|
138 |
+
# window the bases
|
139 |
+
forward_basis *= fft_window
|
140 |
+
inverse_basis *= fft_window
|
141 |
+
|
142 |
+
self.register_buffer("forward_basis", forward_basis.float())
|
143 |
+
self.register_buffer("inverse_basis", inverse_basis.float())
|
144 |
+
|
145 |
+
def transform(self, input_data):
|
146 |
+
num_batches = input_data.size(0)
|
147 |
+
num_samples = input_data.size(1)
|
148 |
+
|
149 |
+
self.num_samples = num_samples
|
150 |
+
|
151 |
+
# similar to librosa, reflect-pad the input
|
152 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
153 |
+
input_data = F.pad(
|
154 |
+
input_data.unsqueeze(1),
|
155 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
156 |
+
mode="reflect",
|
157 |
+
)
|
158 |
+
input_data = input_data.squeeze(1)
|
159 |
+
|
160 |
+
forward_transform = F.conv1d(
|
161 |
+
input_data,
|
162 |
+
Variable(self.forward_basis, requires_grad=False),
|
163 |
+
stride=self.hop_length,
|
164 |
+
padding=0,
|
165 |
+
)
|
166 |
+
|
167 |
+
cutoff = int((self.filter_length / 2) + 1)
|
168 |
+
real_part = forward_transform[:, :cutoff, :]
|
169 |
+
imag_part = forward_transform[:, cutoff:, :]
|
170 |
+
|
171 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
172 |
+
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
|
173 |
+
|
174 |
+
return magnitude, phase
|
175 |
+
|
176 |
+
def inverse(self, magnitude, phase):
|
177 |
+
recombine_magnitude_phase = torch.cat(
|
178 |
+
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
179 |
+
)
|
180 |
+
|
181 |
+
inverse_transform = F.conv_transpose1d(
|
182 |
+
recombine_magnitude_phase,
|
183 |
+
Variable(self.inverse_basis, requires_grad=False),
|
184 |
+
stride=self.hop_length,
|
185 |
+
padding=0,
|
186 |
+
)
|
187 |
+
|
188 |
+
if self.window is not None:
|
189 |
+
window_sum = window_sumsquare(
|
190 |
+
self.window,
|
191 |
+
magnitude.size(-1),
|
192 |
+
hop_length=self.hop_length,
|
193 |
+
win_length=self.win_length,
|
194 |
+
n_fft=self.filter_length,
|
195 |
+
dtype=np.float32,
|
196 |
+
)
|
197 |
+
# remove modulation effects
|
198 |
+
approx_nonzero_indices = torch.from_numpy(
|
199 |
+
np.where(window_sum > tiny(window_sum))[0]
|
200 |
+
)
|
201 |
+
window_sum = torch.autograd.Variable(
|
202 |
+
torch.from_numpy(window_sum), requires_grad=False
|
203 |
+
)
|
204 |
+
window_sum = (
|
205 |
+
window_sum.to(inverse_transform.device)
|
206 |
+
if magnitude.is_cuda
|
207 |
+
else window_sum
|
208 |
+
)
|
209 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
210 |
+
approx_nonzero_indices
|
211 |
+
]
|
212 |
+
|
213 |
+
# scale by hop ratio
|
214 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
215 |
+
|
216 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
|
217 |
+
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]
|
218 |
+
|
219 |
+
return inverse_transform
|
220 |
+
|
221 |
+
def forward(self, input_data):
|
222 |
+
self.magnitude, self.phase = self.transform(input_data)
|
223 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
224 |
+
return reconstruction
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
class TorchSTFT(torch.nn.Module):
|
229 |
+
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
|
230 |
+
super().__init__()
|
231 |
+
self.filter_length = filter_length
|
232 |
+
self.hop_length = hop_length
|
233 |
+
self.win_length = win_length
|
234 |
+
self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
|
235 |
+
|
236 |
+
def transform(self, input_data):
|
237 |
+
forward_transform = torch.stft(
|
238 |
+
input_data,
|
239 |
+
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
|
240 |
+
return_complex=True)
|
241 |
+
|
242 |
+
return torch.abs(forward_transform), torch.angle(forward_transform)
|
243 |
+
|
244 |
+
def inverse(self, magnitude, phase):
|
245 |
+
inverse_transform = torch.istft(
|
246 |
+
magnitude * torch.exp(phase * 1j),
|
247 |
+
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
|
248 |
+
|
249 |
+
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
|
250 |
+
|
251 |
+
def forward(self, input_data):
|
252 |
+
self.magnitude, self.phase = self.transform(input_data)
|
253 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
254 |
+
return reconstruction
|
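TorchSTFT is an analysis/synthesis pair, so a transform followed by inverse approximately reconstructs the input. A quick round-trip sketch with arbitrary sizes:

import torch

stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16)
x = torch.randn(1, 4000)
mag, phase = stft.transform(x)
y = stft.inverse(mag, phase)  # (1, 1, ~4000), close to x up to boundary frames
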
RingFormer/train.py
ADDED
@@ -0,0 +1,671 @@
1 |
+
# import warnings
|
2 |
+
# warnings.simplefilter(action='ignore', category=FutureWarning)
|
3 |
+
# import itertools
|
4 |
+
# import os
|
5 |
+
# import time
|
6 |
+
# import argparse
|
7 |
+
# import json
|
8 |
+
# import torch
|
9 |
+
# import torch.nn.functional as F
|
10 |
+
# from torch.utils.tensorboard import SummaryWriter
|
11 |
+
# from torch.utils.data import DistributedSampler, DataLoader
|
12 |
+
# import torch.multiprocessing as mp
|
13 |
+
# from torch.distributed import init_process_group
|
14 |
+
# from torch.nn.parallel import DistributedDataParallel
|
15 |
+
# from env import AttrDict, build_env
|
16 |
+
# from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
|
17 |
+
# from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\
|
18 |
+
# discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss, MultiScaleSubbandCQTDiscriminator
|
19 |
+
# from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
|
20 |
+
# from stft import TorchSTFT
|
21 |
+
# from Utils.JDC.model import JDCNet
|
22 |
+
|
23 |
+
# torch.backends.cudnn.benchmark = True
|
24 |
+
|
25 |
+
|
26 |
+
# def train(rank, a, h):
|
27 |
+
# if h.num_gpus > 1:
|
28 |
+
# init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
|
29 |
+
# world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
|
30 |
+
|
31 |
+
# torch.cuda.manual_seed(h.seed)
|
32 |
+
# device = torch.device('cuda:{:d}'.format(rank))
|
33 |
+
|
34 |
+
# F0_model = JDCNet(num_class=1, seq_len=192)
|
35 |
+
# params = torch.load(h.F0_path)['net']
|
36 |
+
# F0_model.load_state_dict(params)
|
37 |
+
|
38 |
+
# generator = Generator(h, F0_model, device=device).to(device)
|
39 |
+
# mpd = MultiPeriodDiscriminator().to(device)
|
40 |
+
# msd = MultiScaleSubbandCQTDiscriminator().to(device)
|
41 |
+
# stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
|
42 |
+
|
43 |
+
# if rank == 0:
|
44 |
+
# print(generator)
|
45 |
+
# os.makedirs(a.checkpoint_path, exist_ok=True)
|
46 |
+
# print("checkpoints directory : ", a.checkpoint_path)
|
47 |
+
|
48 |
+
# if os.path.isdir(a.checkpoint_path):
|
49 |
+
# cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
|
50 |
+
# cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
|
51 |
+
|
52 |
+
# steps = 0
|
53 |
+
# if cp_g is None or cp_do is None:
|
54 |
+
# state_dict_do = None
|
55 |
+
# last_epoch = -1
|
56 |
+
# else:
|
57 |
+
# state_dict_g = load_checkpoint(cp_g, device)
|
58 |
+
# state_dict_do = load_checkpoint(cp_do, device)
|
59 |
+
# generator.load_state_dict(state_dict_g['generator'])
|
60 |
+
# mpd.load_state_dict(state_dict_do['mpd'])
|
61 |
+
# msd.load_state_dict(state_dict_do['msd'])
|
62 |
+
# steps = state_dict_do['steps'] + 1
|
63 |
+
# last_epoch = state_dict_do['epoch']
|
64 |
+
|
65 |
+
# if h.num_gpus > 1:
|
66 |
+
# generator = DistributedDataParallel(generator, device_ids=[rank], find_unused_parameters=True).to(device)
|
67 |
+
# mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
68 |
+
# msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
|
69 |
+
|
70 |
+
# optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
71 |
+
# optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
|
72 |
+
# h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
73 |
+
|
74 |
+
# if state_dict_do is not None:
|
75 |
+
# optim_g.load_state_dict(state_dict_do['optim_g'])
|
76 |
+
# optim_d.load_state_dict(state_dict_do['optim_d'])
|
77 |
+
|
78 |
+
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
79 |
+
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
80 |
+
|
81 |
+
# training_filelist, validation_filelist = get_dataset_filelist(a)
|
82 |
+
|
83 |
+
# trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
|
84 |
+
# h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
|
85 |
+
# shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
|
86 |
+
# fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
|
87 |
+
|
88 |
+
# train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
89 |
+
|
90 |
+
# train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
|
91 |
+
# sampler=train_sampler,
|
92 |
+
# batch_size=h.batch_size,
|
93 |
+
# pin_memory=True,
|
94 |
+
# drop_last=True)
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
# if rank == 0:
|
100 |
+
# validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
|
101 |
+
# h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
|
102 |
+
#                                   fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
#                                   base_mels_path=a.input_mels_dir)
#         validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
#                                        sampler=None,
#                                        batch_size=1,
#                                        pin_memory=True,
#                                        drop_last=True)

#         sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

#     generator.train()
#     mpd.train()
#     msd.train()
#     for epoch in range(max(0, last_epoch), a.training_epochs):
#         if rank == 0:
#             start = time.time()
#             print("Epoch: {}".format(epoch+1))

#         if h.num_gpus > 1:
#             train_sampler.set_epoch(epoch)

#         for i, batch in enumerate(train_loader):
#             if rank == 0:
#                 start_b = time.time()
#             x, y, _, y_mel = batch
#             x = torch.autograd.Variable(x.to(device, non_blocking=True))
#             y = torch.autograd.Variable(y.to(device, non_blocking=True))
#             y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
#             y = y.unsqueeze(1)
#             # y_g_hat = generator(x)
#             spec, phase = generator(x)

#             y_g_hat = stft.inverse(spec, phase)

#             y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
#                                           h.fmin, h.fmax_for_loss)

#             optim_d.zero_grad()

#             # MPD
#             y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
#             loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
#             loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)

#             # MSD
#             y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
#             loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
#             loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

#             loss_disc_all = loss_disc_s + loss_disc_f

#             loss_disc_all.backward()
#             optim_d.step()

#             # Generator
#             optim_g.zero_grad()

#             # L1 Mel-Spectrogram Loss
#             loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

#             y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
#             y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
#             loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
#             loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
#             loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
#             loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

#             loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
#             loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

#             loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

#             loss_gen_all.backward()
#             optim_g.step()

#             if rank == 0:
#                 # STDOUT logging
#                 if steps % a.stdout_interval == 0:
#                     with torch.no_grad():
#                         mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

#                     print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
#                           format(steps, loss_gen_all, mel_error, time.time() - start_b))

#                 # checkpointing
#                 if steps % a.checkpoint_interval == 0 and steps != 0:
#                     checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
#                     save_checkpoint(checkpoint_path,
#                                     {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
#                     checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
#                     save_checkpoint(checkpoint_path,
#                                     {'mpd': (mpd.module if h.num_gpus > 1
#                                              else mpd).state_dict(),
#                                      'msd': (msd.module if h.num_gpus > 1
#                                              else msd).state_dict(),
#                                      'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
#                                      'epoch': epoch})

#                 # Tensorboard summary logging
#                 if steps % a.summary_interval == 0:
#                     sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
#                     sw.add_scalar("training/mel_spec_error", mel_error, steps)

#                 # Validation
#                 if steps % a.validation_interval == 0:  # and steps != 0:
#                     generator.eval()
#                     torch.cuda.empty_cache()
#                     val_err_tot = 0
#                     with torch.no_grad():
#                         for j, batch in enumerate(validation_loader):
#                             x, y, _, y_mel = batch
#                             # y_g_hat = generator(x.to(device))
#                             spec, phase = generator(x.to(device))

#                             y_g_hat = stft.inverse(spec, phase)

#                             y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
#                             y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
#                                                           h.hop_size, h.win_size,
#                                                           h.fmin, h.fmax_for_loss)
#                             val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

#                             if j <= 4:
#                                 if steps == 0:
#                                     sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
#                                     sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

#                                 sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
#                                 y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
#                                                              h.sampling_rate, h.hop_size, h.win_size,
#                                                              h.fmin, h.fmax)
#                                 sw.add_figure('generated/y_hat_spec_{}'.format(j),
#                                               plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

#                         val_err = val_err_tot / (j+1)
#                         sw.add_scalar("validation/mel_spec_error", val_err, steps)

#                     generator.train()

#             steps += 1

#         scheduler_g.step()
#         scheduler_d.step()

#         if rank == 0:
#             print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))


# def main():
#     print('Initializing Training Process..')

#     parser = argparse.ArgumentParser()

#     parser.add_argument('--group_name', default=None)
#     parser.add_argument('--input_wavs_dir', default='')
#     parser.add_argument('--input_mels_dir', default='ft_dataset')
#     # parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/training.txt')
#     parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/validation.txt')
#     parser.add_argument('--input_validation_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/validation.txt')
#     parser.add_argument('--checkpoint_path', default='cp_ringformer_24')
#     parser.add_argument('--config', default='config_v1.json')
#     parser.add_argument('--training_epochs', default=3100, type=int)
#     parser.add_argument('--stdout_interval', default=5, type=int)
#     parser.add_argument('--checkpoint_interval', default=5000, type=int)
#     parser.add_argument('--summary_interval', default=100, type=int)
#     parser.add_argument('--validation_interval', default=1000, type=int)
#     parser.add_argument('--fine_tuning', default=False, type=bool)

#     a = parser.parse_args()

#     with open(a.config) as f:
#         data = f.read()

#     json_config = json.loads(data)
#     h = AttrDict(json_config)
#     build_env(a.config, 'config.json', a.checkpoint_path)

#     torch.manual_seed(h.seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(h.seed)
#         h.num_gpus = torch.cuda.device_count()
#         h.batch_size = int(h.batch_size / h.num_gpus)
#         print('Batch size per GPU :', h.batch_size)
#     else:
#         pass

#     if h.num_gpus > 1:
#         mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
#     else:
#         train(0, a, h)


# if __name__ == '__main__':
#     main()

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import itertools
import os
import time
import argparse
import json
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from env import AttrDict, build_env
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\
    discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss, MultiScaleSubbandCQTDiscriminator
from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
from stft import TorchSTFT
from Utils.JDC.model import JDCNet
from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import DistributedDataParallelKwargs


torch.backends.cudnn.benchmark = True

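
# NOTE (editor): the Generator below returns a (magnitude, phase) pair rather than a raw
# waveform; TorchSTFT.inverse (stft.py) turns that pair back into audio. A rough, hedged
# sketch of what such an inverse step usually looks like -- the authoritative version
# lives in stft.py, and `n_fft`/`hop_size` here are placeholder names:
#
#     def istft_sketch(spec, phase, n_fft, hop_size):
#         window = torch.hann_window(n_fft, device=spec.device)
#         complex_spec = spec * torch.exp(1j * phase)   # recombine magnitude and phase
#         return torch.istft(complex_spec, n_fft, hop_length=hop_size,
#                            win_length=n_fft, window=window)
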
def train(accelerator, a, h):
    # if h.num_gpus > 1:
    #     init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
    #                        world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = accelerator.device

    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(h.F0_path)['model']
    F0_model.load_state_dict(params)

    generator = Generator(h, F0_model).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleSubbandCQTDiscriminator().to(device)
    stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)

    with accelerator.main_process_first():
        accelerator.print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        accelerator.print("checkpoints directory : ", a.checkpoint_path)

    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    # if h.num_gpus > 1:
    generator = accelerator.prepare(generator).to(device)
    mpd = accelerator.prepare(mpd).to(device)
    msd = accelerator.prepare(msd).to(device)

    optim_g = accelerator.prepare(torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]))
    optim_d = accelerator.prepare(torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                                    h.learning_rate, betas=[h.adam_b1, h.adam_b2]))

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = accelerator.prepare(torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch))
    scheduler_d = accelerator.prepare(torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch))
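
    # NOTE (editor): both schedulers step once per epoch (see the end of the epoch loop),
    # so the learning rate decays as lr_epoch = h.learning_rate * h.lr_decay ** epoch.
    # For example, with the upstream HiFi-GAN defaults (learning_rate=2e-4, lr_decay=0.999),
    # the rate after 1000 epochs would be about 2e-4 * 0.999**1000 ~= 7.4e-5.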

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    # train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    accelerator.print("bs", h.batch_size)

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              # sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    train_loader = accelerator.prepare(train_loader)
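
    # NOTE (editor): accelerator.prepare() wraps the DataLoader so that each process
    # receives its own shard of the data, which is why the DistributedSampler above
    # is left commented out -- accelerate handles the sharding itself.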

    validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                          fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                          base_mels_path=a.input_mels_dir)

    validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                   sampler=None,
                                   batch_size=1,
                                   pin_memory=True,
                                   drop_last=True)

    sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    # validation_loader = accelerator.prepare(validation_loader)

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        with accelerator.main_process_first():
            start = time.time()
            accelerator.print("Epoch: {}".format(epoch+1))

        # if h.num_gpus > 1:
        #     train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            # with accelerator.main_process_first():
            start_b = time.time()
            x, y, _, y_mel = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
            y = y.unsqueeze(1)
            # y_g_hat = generator(x)
            spec, phase = generator(x)

            y_g_hat = stft.inverse(spec, phase)

            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
                                          h.fmin, h.fmax_for_loss)

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
            loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
            loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            accelerator.backward(loss_disc_all)
            # accelerator.clip_grad_norm_(mpd.parameters(), max_norm=10.0)
            # accelerator.clip_grad_norm_(msd.parameters(), max_norm=10.0)

            # loss_disc_all.backward()
            optim_d.step()
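
            # NOTE (editor): discriminator_loss / discriminator_TPRLS_loss are defined in
            # models.py. Assuming the standard HiFi-GAN least-squares formulation, the core
            # term per sub-discriminator is a sketch like:
            #
            #     loss += torch.mean((1 - d_real) ** 2) + torch.mean(d_gen ** 2)
            #
            # with the TPRLS term adding a relativistic regularizer on top of it.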

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

            loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
            loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            accelerator.backward(loss_gen_all)

            # accelerator.clip_grad_norm_(generator.parameters(), max_norm=10.0)

            # loss_gen_all.backward()
            optim_g.step()
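
            # NOTE (editor): assuming the usual HiFi-GAN definitions in models.py,
            # feature_loss is an L1 feature-matching term over discriminator activations
            # and generator_loss is the least-squares adversarial term, roughly:
            #
            #     loss_fm  += torch.mean(torch.abs(real_fmap - gen_fmap))   # per layer
            #     loss_gen += torch.mean((1 - d_gen) ** 2)                  # per sub-discriminator
            #
            # The mel-loss weight of 45 matches the HiFi-GAN recipe.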

            # accelerator.print('done +', i)

            # STDOUT logging
            if steps % a.stdout_interval == 0:

                with accelerator.main_process_first():

                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    accelerator.print('Steps : {:d}, Gen Loss Total : {:4.3f}, Disc Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
                                      format(steps, loss_gen_all, loss_disc_all, mel_error, time.time() - start_b))

            # checkpointing
            # In the main code:
            # if steps % a.checkpoint_interval == 0 and steps != 0:
            #     # with accelerator.main_process_first():
            #     # Unwrap the generator
            #     unwrapped_generator = accelerator.unwrap_model(generator)
            #     checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
            #     save_checkpoint(checkpoint_path,
            #                     {'generator': unwrapped_generator.state_dict()})

            #     # Unwrap discriminators
            #     unwrapped_mpd = accelerator.unwrap_model(mpd)
            #     unwrapped_msd = accelerator.unwrap_model(msd)

            #     checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
            #     save_checkpoint(checkpoint_path,
            #                     {'mpd': unwrapped_mpd.state_dict(),
            #                      'msd': unwrapped_msd.state_dict(),
            #                      'optim_g': optim_g.state_dict(),
            #                      'optim_d': optim_d.state_dict(),
            #                      'steps': steps,
            #                      'epoch': epoch})

            if steps % a.checkpoint_interval == 0 and steps != 0:

                with accelerator.main_process_first():
                    checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'mpd': (mpd.module if h.num_gpus > 1
                                             else mpd).state_dict(),
                                     'msd': (msd.module if h.num_gpus > 1
                                             else msd).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})

            # Tensorboard summary logging
            if steps % a.summary_interval == 0:
                sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                sw.add_scalar("training/mel_spec_error", mel_error, steps)

            # Validation
            if steps % a.validation_interval == 0:  # and steps != 0:

                # with accelerator.main_process_first():
                accelerator.print('evaluation...')

                generator.eval()
                torch.cuda.empty_cache()
                val_err_tot = 0
                with torch.no_grad():
                    for j, batch in enumerate(validation_loader):
                        x, y, _, y_mel = batch
                        # y_g_hat = generator(x.to(device))
                        spec, phase = generator(x.to(device))

                        y_g_hat = stft.inverse(spec, phase)

                        y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
                        y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                      h.hop_size, h.win_size,
                                                      h.fmin, h.fmax_for_loss)

                        accelerator.print('done: ', j)

                        val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                        if j <= 4:
                            if steps == 0:
                                sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                                sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                            sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                            y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                         h.sampling_rate, h.hop_size, h.win_size,
                                                         h.fmin, h.fmax)

                            sw.add_figure('generated/y_hat_spec_{}'.format(j),
                                          plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

                    val_err = val_err_tot / (j+1)
                    sw.add_scalar("validation/mel_spec_error", val_err, steps)

                generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        # with accelerator.main_process_first():
        accelerator.print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))


def main():

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    parser = argparse.ArgumentParser()

    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_wavs_dir', default='')
    parser.add_argument('--input_mels_dir', default='ft_dataset')
    # parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/training.txt')
    parser.add_argument('--input_training_file', default='/home/ubuntu/respair/audio_files_over_2sec_SUB.txt')
    parser.add_argument('--input_validation_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/validation.txt')
    parser.add_argument('--checkpoint_path', default='cp_ringformer_44.1khz_NUM2')
    parser.add_argument('--config', default='config_v1.json')
    parser.add_argument('--training_epochs', default=3100, type=int)
    parser.add_argument('--stdout_interval', default=10, type=int)
    parser.add_argument('--checkpoint_interval', default=1000, type=int)
    parser.add_argument('--summary_interval', default=100, type=int)
    parser.add_argument('--validation_interval', default=1000, type=int)
    # store_true instead of type=bool: argparse's bool('False') is truthy, so
    # type=bool would silently enable fine-tuning for any explicit value.
    parser.add_argument('--fine_tuning', default=False, action='store_true')

    a = parser.parse_args()

    accelerator = Accelerator(project_dir=a.checkpoint_path, split_batches=False,
                              kwargs_handlers=[ddp_kwargs])

    accelerator.print('Initializing Training Process..')

    with open(a.config) as f:
        data = f.read()

    json_config = json.loads(data)
    h = AttrDict(json_config)
    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        # h.batch_size = int(h.batch_size / h.num_gpus)
        # print('Batch size per GPU :', h.batch_size)
    else:
        pass
    # Pass accelerator to train function instead of rank
    train(accelerator, a, h)


if __name__ == '__main__':
    main()
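
# NOTE (editor): train() now expects an Accelerator rather than a process rank, so the
# intended entry point is the accelerate CLI; a typical invocation (flags illustrative):
#
#     accelerate launch train.py --config config_v1.json --checkpoint_path cp_ringformer
#
# Plain `python train.py` also works and simply runs a single-process Accelerator.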
RingFormer/utils.py
ADDED
@@ -0,0 +1,58 @@
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)
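
# NOTE (editor): for example, get_padding(7) == 3 and get_padding(7, dilation=3) == 9 --
# the "same" padding that keeps a stride-1 Conv1d output the same length as its input.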


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
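
# NOTE (editor): a minimal resume flow combining the helpers above, as used by train.py
# (the directory name is illustrative):
#
#     cp_g = scan_checkpoint('cp_ringformer', 'g_')   # e.g. 'cp_ringformer/g_00050000'
#     if cp_g is not None:
#         state = load_checkpoint(cp_g, 'cpu')
#         generator.load_state_dict(state['generator'])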