Respair committed
Commit eb29d0a · verified · 1 Parent(s): 6313a2f

Upload folder using huggingface_hub

Files changed (35)
  1. .gitattributes +2 -0
  2. RingFormer/LICENSE +21 -0
  3. RingFormer/README.md +38 -0
  4. RingFormer/Utils/JDC/__init__.py +1 -0
  5. RingFormer/Utils/JDC/__pycache__/__init__.cpython-311.pyc +0 -0
  6. RingFormer/Utils/JDC/__pycache__/__init__.cpython-38.pyc +0 -0
  7. RingFormer/Utils/JDC/__pycache__/__init__.cpython-39.pyc +0 -0
  8. RingFormer/Utils/JDC/__pycache__/bst.t7 +3 -0
  9. RingFormer/Utils/JDC/__pycache__/model.cpython-311.pyc +0 -0
  10. RingFormer/Utils/JDC/__pycache__/model.cpython-38.pyc +0 -0
  11. RingFormer/Utils/JDC/__pycache__/model.cpython-39.pyc +0 -0
  12. RingFormer/Utils/JDC/bst.t7 +3 -0
  13. RingFormer/Utils/JDC/model.py +192 -0
  14. RingFormer/Utils/__init__.py +1 -0
  15. RingFormer/Utils/__pycache__/__init__.cpython-311.pyc +0 -0
  16. RingFormer/Utils/__pycache__/__init__.cpython-38.pyc +0 -0
  17. RingFormer/Utils/__pycache__/__init__.cpython-39.pyc +0 -0
  18. RingFormer/__pycache__/conformer.cpython-311.pyc +0 -0
  19. RingFormer/__pycache__/env.cpython-311.pyc +0 -0
  20. RingFormer/__pycache__/meldataset.cpython-311.pyc +0 -0
  21. RingFormer/__pycache__/models.cpython-311.pyc +0 -0
  22. RingFormer/__pycache__/norm2d.cpython-311.pyc +0 -0
  23. RingFormer/__pycache__/stft.cpython-311.pyc +0 -0
  24. RingFormer/__pycache__/utils.cpython-311.pyc +0 -0
  25. RingFormer/config_v1.json +42 -0
  26. RingFormer/conformer.py +228 -0
  27. RingFormer/env.py +15 -0
  28. RingFormer/inference.ipynb +292 -0
  29. RingFormer/meldataset.py +203 -0
  30. RingFormer/models.py +943 -0
  31. RingFormer/norm2d.py +92 -0
  32. RingFormer/requirements.txt +10 -0
  33. RingFormer/stft.py +254 -0
  34. RingFormer/train.py +671 -0
  35. RingFormer/utils.py +58 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ RingFormer/Utils/JDC/__pycache__/bst.t7 filter=lfs diff=lfs merge=lfs -text
+ RingFormer/Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
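The two new rules route the `bst.t7` checkpoints through Git LFS, so a clone made without LFS installed leaves small pointer files behind instead of the real weights. A minimal sketch (hypothetical helper, not part of this repo) for detecting an un-fetched pointer before trying to load it:

```python
# Minimal sketch: check whether an LFS-tracked checkpoint was actually fetched.
# A file that still contains the Git LFS pointer text cannot be loaded with torch.load.
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    """Return True if `path` holds a Git LFS pointer instead of real data."""
    with open(path, "rb") as f:
        head = f.read(64)
    return head.startswith(b"version https://git-lfs.github.com/spec/")

ckpt = "RingFormer/Utils/JDC/bst.t7"
if Path(ckpt).exists() and is_lfs_pointer(ckpt):
    print(f"{ckpt}: still an LFS pointer -- run `git lfs pull` first")
```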
RingFormer/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Aaron (Yinghao) Li
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
RingFormer/README.md ADDED
@@ -0,0 +1,38 @@
+ # HiFTNet: A Fast High-Quality Neural Vocoder with Harmonic-plus-Noise Filter and Inverse Short Time Fourier Transform
+
+ ### Yinghao Aaron Li, Cong Han, Xilin Jiang, Nima Mesgarani
+
+ > Recent advancements in speech synthesis have leveraged GAN-based networks like HiFi-GAN and BigVGAN to produce high-fidelity waveforms from mel-spectrograms. However, these networks are computationally expensive and parameter-heavy. iSTFTNet addresses these limitations by integrating inverse short-time Fourier transform (iSTFT) into the network, achieving both speed and parameter efficiency. In this paper, we introduce an extension to iSTFTNet, termed HiFTNet, which incorporates a harmonic-plus-noise source filter in the time-frequency domain that uses a sinusoidal source from the fundamental frequency (F0) inferred via a pre-trained F0 estimation network for fast inference speed. Subjective evaluations on LJSpeech show that our model significantly outperforms both iSTFTNet and HiFi-GAN, achieving ground-truth-level performance. HiFTNet also outperforms BigVGAN-base on LibriTTS for unseen speakers and achieves comparable performance to BigVGAN while being four times faster with only 1/6 of the parameters. Our work sets a new benchmark for efficient, high-quality neural vocoding, paving the way for real-time applications that demand high-quality speech synthesis.
+
+ Paper: [https://arxiv.org/abs/2309.09493](https://arxiv.org/abs/2309.09493)
+
+ Audio samples: [https://hiftnet.github.io/](https://hiftnet.github.io/)
+
+ **Check out our TTS work that uses HiFTNet as the speech decoder for human-level speech synthesis here: https://github.com/yl4579/StyleTTS2**
+
+ ## Pre-requisites
+ 1. Python >= 3.7
+ 2. Clone this repository:
+ ```bash
+ git clone https://github.com/yl4579/HiFTNet.git
+ cd HiFTNet
+ ```
+ 3. Install python requirements:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Training
+ ```bash
+ python train.py --config config_v1.json --[args]
+ ```
+ For F0 model training, please refer to [yl4579/PitchExtractor](https://github.com/yl4579/PitchExtractor). This repo includes an F0 model pre-trained on LibriTTS. Still, you may want to train your own F0 model for the best performance, particularly on noisy or non-speech data, as we found that F0 estimation accuracy is essential to vocoder performance.
+
+ ## Inference
+ Please refer to the notebook [inference.ipynb](https://github.com/yl4579/HiFTNet/blob/main/inference.ipynb) for details.
+ ### Pre-Trained Models
+ You can download the pre-trained LJSpeech model [here](https://huggingface.co/yl4579/HiFTNet/blob/main/LJSpeech/cp_hifigan.zip) and the pre-trained LibriTTS model [here](https://huggingface.co/yl4579/HiFTNet/blob/main/LibriTTS/cp_hifigan.zip). The pre-trained models contain parameters of the optimizers and discriminators that can be used for fine-tuning.
+
+ ## References
+ - [rishikksh20/iSTFTNet-pytorch](https://github.com/rishikksh20/iSTFTNet-pytorch)
+ - [nii-yamagishilab/project-NN-Pytorch-scripts/project/01-nsf](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf)
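For a quick smoke test outside the notebook, the resynthesis cells of `RingFormer/inference.ipynb` (included in this commit) boil down to the following sketch. The checkpoint folder and checkpoint name are placeholders; adapt them to your own setup.

```python
# Minimal inference sketch distilled from RingFormer/inference.ipynb.
# `cp_path` and the `g_*` checkpoint name are placeholders, not shipped files.
import json
import torch
from env import AttrDict
from models import Generator
from stft import TorchSTFT
from Utils.JDC.model import JDCNet

cp_path = "cp_ringformer_44.1khz"                       # hypothetical checkpoint folder
with open(f"{cp_path}/config.json") as f:
    h = AttrDict(json.load(f))

device = "cuda:0" if torch.cuda.is_available() else "cpu"

F0_model = JDCNet(num_class=1, seq_len=192)             # pitch extractor backbone
generator = Generator(h, F0_model).to(device)
stft = TorchSTFT(filter_length=h.gen_istft_n_fft,
                 hop_length=h.gen_istft_hop_size,
                 win_length=h.gen_istft_n_fft).to(device)

state = torch.load(f"{cp_path}/g_00017000", map_location=device)
generator.load_state_dict(state["generator"])
generator.remove_weight_norm()
generator.eval()

with torch.no_grad():
    mel = torch.randn(1, h.num_mels, 128, device=device)  # stand-in mel batch
    spec, phase = generator(mel)                           # generator predicts magnitude and phase
    audio = stft.inverse(spec, phase).squeeze().cpu().numpy()
```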
RingFormer/Utils/JDC/__init__.py ADDED
@@ -0,0 +1 @@
+
RingFormer/Utils/JDC/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (154 Bytes).
RingFormer/Utils/JDC/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (129 Bytes).
RingFormer/Utils/JDC/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (129 Bytes).
 
RingFormer/Utils/JDC/__pycache__/bst.t7 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a7939fda04b8ef20365a6e6132e8190addd1ffdfb4c3f7a6892d0d7d6fe70d8
+ size 63049083
RingFormer/Utils/JDC/__pycache__/model.cpython-311.pyc ADDED
Binary file (10.3 kB).
RingFormer/Utils/JDC/__pycache__/model.cpython-38.pyc ADDED
Binary file (4.73 kB).
RingFormer/Utils/JDC/__pycache__/model.cpython-39.pyc ADDED
Binary file (4.72 kB).
 
RingFormer/Utils/JDC/bst.t7 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
+ size 21029926
RingFormer/Utils/JDC/model.py ADDED
@@ -0,0 +1,192 @@
+ """
+ Implementation of model from:
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
+ Convolutional Recurrent Neural Networks" (2019)
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
+ """
+ import torch
+ from torch import nn
+
+ class JDCNet(nn.Module):
+     """
+     Joint Detection and Classification Network model for singing voice melody.
+     """
+
+     def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
+         super().__init__()
+         self.num_class = num_class
+
+         # input = (b, 1, 31, 513), b = batch size
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False),  # out: (b, 64, 31, 513)
+             nn.BatchNorm2d(num_features=64),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
+         )
+
+         # res blocks
+         self.res_block1 = ResBlock(in_channels=64, out_channels=128)   # (b, 128, 31, 128)
+         self.res_block2 = ResBlock(in_channels=128, out_channels=192)  # (b, 192, 31, 32)
+         self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)
+
+         # pool block
+         self.pool_block = nn.Sequential(
+             nn.BatchNorm2d(num_features=256),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
+             nn.Dropout(p=0.2),
+         )
+
+         # maxpool layers (for auxiliary network inputs)
+         # in = (b, 64, 31, 128) from conv_block, out = (b, 64, 31, 4)
+         self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 32))  # 128 / 32 = 4
+         # in = (b, 128, 31, 64) from res_block1, out = (b, 128, 31, 4)
+         self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 16))  # 64 / 16 = 4
+         # in = (b, 192, 31, 32) from res_block2, out = (b, 192, 31, 4)
+         self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 8))   # 32 / 8 = 4
+
+         # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
+         self.detector_conv = nn.Sequential(
+             nn.Conv2d(640, 256, 1, bias=False),
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Dropout(p=0.2),
+         )
+
+         # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+         self.bilstm_classifier = nn.LSTM(
+             input_size=1024, hidden_size=256,
+             batch_first=True, bidirectional=True)  # (b, 31, 512)
+
+         # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+         self.bilstm_detector = nn.LSTM(
+             input_size=1024, hidden_size=256,
+             batch_first=True, bidirectional=True)  # (b, 31, 512)
+
+         # input: (b * 31, 512)
+         self.classifier = nn.Linear(in_features=512, out_features=self.num_class)  # (b * 31, num_class)
+
+         # input: (b * 31, 512)
+         self.detector = nn.Linear(in_features=512, out_features=2)  # (b * 31, 2) - binary classifier
+
+         # initialize weights
+         self.apply(self.init_weights)
+
+     def get_feature_GAN(self, x):
+         seq_len = x.shape[-2]
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+
+         return poolblock_out.transpose(-1, -2)
+
+     def get_feature(self, x):
+         seq_len = x.shape[-2]
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+
+         return self.pool_block[2](poolblock_out)
+
+     def forward(self, x):
+         """
+         Returns:
+             classification_prediction, detection_prediction
+             sizes: (b, 31, 722), (b, 31, 2)
+         """
+         ###############################
+         # forward pass for classifier #
+         ###############################
+         seq_len = x.shape[-1]
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+         GAN_feature = poolblock_out.transpose(-1, -2)
+         poolblock_out = self.pool_block[2](poolblock_out)
+
+         # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+         classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 1024))
+         classifier_out, _ = self.bilstm_classifier(classifier_out)  # ignore the hidden states
+
+         classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
+         classifier_out = self.classifier(classifier_out)
+         classifier_out = classifier_out.view((-1, seq_len, self.num_class))  # (b, 31, num_class)
+
+         # sizes: (b, 31, 722), (b, 31, 2)
+         # classifier output consists of predicted pitch classes per frame
+         # detector output consists of: (isvoice, notvoice) estimates per frame
+         return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
+
+     @staticmethod
+     def init_weights(m):
+         if isinstance(m, nn.Linear):
+             nn.init.kaiming_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.Conv2d):
+             nn.init.xavier_normal_(m.weight)
+         elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
+             for p in m.parameters():
+                 if p.data is None:
+                     continue
+
+                 if len(p.shape) >= 2:
+                     nn.init.orthogonal_(p.data)
+                 else:
+                     nn.init.normal_(p.data)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
+         super().__init__()
+         self.downsample = in_channels != out_channels
+
+         # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
+         self.pre_conv = nn.Sequential(
+             nn.BatchNorm2d(num_features=in_channels),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
+         )
+
+         # conv layers
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+                       kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+         )
+
+         # 1 x 1 convolution layer to match the feature dimensions
+         self.conv1by1 = None
+         if self.downsample:
+             self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+
+     def forward(self, x):
+         x = self.pre_conv(x)
+         if self.downsample:
+             x = self.conv(x) + self.conv1by1(x)
+         else:
+             x = self.conv(x) + x
+         return x
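As used in this repo (see `config_v1.json` and `models.py` below), `JDCNet` is instantiated with `num_class=1` and fed an unsqueezed 128-bin mel spectrogram, so it acts as a frame-level F0 regressor rather than a 722-way pitch classifier. A hedged shape check, assuming the file above is importable as `Utils.JDC.model`:

```python
# Hedged shape check for JDCNet as configured in this repo (num_class=1, 128 mel bins).
# The bilstm_classifier expects 256 * (n_mels / 32) == 1024 features per frame,
# so the mel dimension must be 128 for this forward pass to work.
import torch
from Utils.JDC.model import JDCNet

model = JDCNet(num_class=1, seq_len=192).eval()

mel = torch.randn(2, 128, 192)                            # (batch, n_mels, frames)
with torch.no_grad():
    f0, gan_feature, pool_out = model(mel.unsqueeze(1))   # add the channel dimension

print(f0.shape)           # torch.Size([2, 192]) -- one F0 value per frame
print(gan_feature.shape)  # intermediate feature map, reused elsewhere as a GAN feature
```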
RingFormer/Utils/__init__.py ADDED
@@ -0,0 +1 @@
+
RingFormer/Utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (150 Bytes).
RingFormer/Utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (125 Bytes).
RingFormer/Utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (125 Bytes).
 
RingFormer/__pycache__/conformer.cpython-311.pyc ADDED
Binary file (11.2 kB).
RingFormer/__pycache__/env.cpython-311.pyc ADDED
Binary file (1.32 kB).
RingFormer/__pycache__/meldataset.cpython-311.pyc ADDED
Binary file (12.9 kB).
RingFormer/__pycache__/models.cpython-311.pyc ADDED
Binary file (49.7 kB).
RingFormer/__pycache__/norm2d.cpython-311.pyc ADDED
Binary file (4.6 kB).
RingFormer/__pycache__/stft.cpython-311.pyc ADDED
Binary file (13.4 kB).
RingFormer/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.46 kB).
 
RingFormer/config_v1.json ADDED
@@ -0,0 +1,42 @@
+ {
+     "F0_path": "/home/ubuntu/Darya_Speech/Utils/JDC/epoch_00079.pth",
+
+     "resblock": "1",
+     "num_gpus": 1,
+     "batch_size": 18,
+     "learning_rate": 0.0002,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.999,
+     "seed": 1234,
+
+     "upsample_rates": [16, 8],
+     "upsample_kernel_sizes": [32, 16],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3, 7, 11],
+     "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+     "gen_istft_n_fft": 32,
+     "gen_istft_hop_size": 4,
+
+     "segment_size": 65536,
+     "num_mels": 128,
+     "n_fft": 2048,
+     "hop_size": 512,
+     "win_size": 2048,
+
+     "sampling_rate": 44100,
+
+     "fmin": 0,
+     "fmax": null,
+     "fmax_for_loss": null,
+
+     "num_workers": 8,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
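These values are internally consistent: the generator's total upsampling factor times the iSTFT hop must reproduce the mel hop size, and the training segment must be a whole number of frames. A small hedged check of that arithmetic (the config path assumes you run from the repo root):

```python
# Sanity-check the frame math implied by config_v1.json (values copied from above).
import json
import numpy as np

with open("RingFormer/config_v1.json") as f:   # path assumes the repo root
    h = json.load(f)

total_upsample = np.prod(h["upsample_rates"]) * h["gen_istft_hop_size"]
assert total_upsample == h["hop_size"]         # 16 * 8 * 4 == 512 samples per mel frame
assert h["segment_size"] % h["hop_size"] == 0  # 65536 / 512 == 128 frames per training segment

frame_ms = 1000 * h["hop_size"] / h["sampling_rate"]
print(f"{frame_ms:.2f} ms per frame at {h['sampling_rate']} Hz")  # ~11.61 ms
```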
RingFormer/conformer.py ADDED
@@ -0,0 +1,228 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from einops.layers.torch import Rearrange
+ from ring_attention_pytorch import RingAttention
+
+ # helper functions
+
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ def calc_same_padding(kernel_size):
+     pad = kernel_size // 2
+     return (pad, pad - (kernel_size + 1) % 2)
+
+
+ # helper classes
+
+
+ class Swish(nn.Module):
+     def forward(self, x):
+         return x * x.sigmoid()
+
+
+ class GLU(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x):
+         out, gate = x.chunk(2, dim=self.dim)
+         return out * gate.sigmoid()
+
+
+ class DepthWiseConv1d(nn.Module):
+     def __init__(self, chan_in, chan_out, kernel_size, padding):
+         super().__init__()
+         self.padding = padding
+         self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
+
+     def forward(self, x):
+         x = F.pad(x, self.padding)
+         return self.conv(x)
+
+
+ # attention, feedforward, and conv module
+
+
+ class Scale(nn.Module):
+     def __init__(self, scale, fn):
+         super().__init__()
+         self.fn = fn
+         self.scale = scale
+
+     def forward(self, x, **kwargs):
+         return self.fn(x, **kwargs) * self.scale
+
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.fn = fn
+         self.norm = nn.LayerNorm(dim)
+
+     def forward(self, x, **kwargs):
+         x = self.norm(x.to(x.device))
+         out = self.fn(x.to(x.device), **kwargs)
+         return out
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, mult=4, dropout=0.0):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, dim * mult),
+             Swish(),
+             nn.Dropout(dropout),
+             nn.Linear(dim * mult, dim),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class ConformerConvModule(nn.Module):
+     def __init__(
+         self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
+     ):
+         super().__init__()
+
+         inner_dim = dim * expansion_factor
+         padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
+
+         self.net = nn.Sequential(
+             nn.LayerNorm(dim),
+             Rearrange("b n c -> b c n"),
+             nn.Conv1d(dim, inner_dim * 2, 1),
+             GLU(dim=1),
+             DepthWiseConv1d(
+                 inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
+             ),
+             nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
+             Swish(),
+             nn.Conv1d(inner_dim, dim, 1),
+             Rearrange("b c n -> b n c"),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ # Conformer Block
+
+
+ class ConformerBlock(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         dim_head=64,
+         heads=8,
+         ff_mult=4,
+         conv_expansion_factor=2,
+         conv_kernel_size=31,
+         attn_dropout=0.0,
+         ff_dropout=0.0,
+         conv_dropout=0.0,
+         conv_causal=False
+     ):
+         super().__init__()
+         self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+         self.attn = RingAttention(
+             dim=dim,
+             dim_head=dim_head,
+             heads=heads,
+             causal=True,
+             auto_shard_seq=False,  # doesn't work on multi-gpu setup for some reason
+             ring_attn=True,
+             ring_seq_size=512,
+         )
+         self.self_attn_dropout = torch.nn.Dropout(attn_dropout)
+         self.conv = ConformerConvModule(
+             dim=dim,
+             causal=conv_causal,
+             expansion_factor=conv_expansion_factor,
+             kernel_size=conv_kernel_size,
+             dropout=conv_dropout,
+         )
+         self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+
+         self.attn = PreNorm(dim, self.attn)
+         self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
+         self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
+
+         self.post_norm = nn.LayerNorm(dim)
+
+     def forward(self, x, mask=None):
+         x_ff1 = self.ff1(x) + x
+
+         x = self.attn(x, mask=mask)
+         x = self.self_attn_dropout(x)
+         x = x + x_ff1
+         x = self.conv(x) + x
+         x = self.ff2(x) + x
+         x = self.post_norm(x)
+         return x
+
+
+ # Conformer
+
+
+ class Conformer(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         depth,
+         dim_head=64,
+         heads=8,
+         ff_mult=4,
+         conv_expansion_factor=2,
+         conv_kernel_size=31,
+         attn_dropout=0.0,
+         ff_dropout=0.0,
+         conv_dropout=0.0,
+         conv_causal=False
+     ):
+         super().__init__()
+         self.dim = dim
+
+         self.layers = nn.ModuleList([])
+
+         for _ in range(depth):
+             self.layers.append(
+                 ConformerBlock(
+                     dim=dim,
+                     dim_head=dim_head,
+                     heads=heads,
+                     ff_mult=ff_mult,
+                     conv_expansion_factor=conv_expansion_factor,
+                     conv_kernel_size=conv_kernel_size,
+                     conv_causal=conv_causal,
+                 )
+             )
+
+     def forward(self, x):
+         for block in self.layers:
+             x = block(x)
+         return x
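The attention branch depends on the external `ring_attention_pytorch` package, so it is only exercised when that package is installed. The convolution branch is self-contained, and its "same" padding is what lets `Generator.forward` (in `models.py` below) swap between `(b, f, t)` and `(b, t, f)` around each Conformer layer without any cropping. A hedged shape check of that branch alone, assuming the file above is importable as `conformer`:

```python
# Hedged shape check: ConformerConvModule preserves sequence length thanks to
# calc_same_padding, so frames in == frames out.
import torch
from conformer import ConformerConvModule, calc_same_padding

assert calc_same_padding(31) == (15, 15)    # symmetric padding for odd kernels

module = ConformerConvModule(dim=256, kernel_size=31).eval()
x = torch.randn(2, 100, 256)                # (batch, frames, channels)
with torch.no_grad():
    y = module(x)
print(y.shape)                              # torch.Size([2, 100, 256])
```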
RingFormer/env.py ADDED
@@ -0,0 +1,15 @@
+ import os
+ import shutil
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, os.path.join(path, config_name))
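`AttrDict` is what lets the loaded JSON config be read as attributes throughout the code (`h.n_fft`, `h.hop_size`, ...). A two-line illustration:

```python
# AttrDict exposes dict keys as attributes; this is how the parsed config.json
# is accessed as `h.n_fft`, `h.sampling_rate`, etc. elsewhere in the repo.
from env import AttrDict

h = AttrDict({"n_fft": 2048, "hop_size": 512})
assert h.n_fft == 2048 and h["hop_size"] == 512
```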
RingFormer/inference.ipynb ADDED
@@ -0,0 +1,292 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "7b82eb58",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/home/ubuntu/miniconda3/envs/respair/lib/python3.11/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
14
+ " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
15
+ ]
16
+ },
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "/home/ubuntu/RINGFORMER\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "%cd /home/ubuntu/RINGFORMER\n",
27
+ "\n",
28
+ "from __future__ import absolute_import, division, print_function, unicode_literals\n",
29
+ "\n",
30
+ "import glob\n",
31
+ "import os\n",
32
+ "import argparse\n",
33
+ "import json\n",
34
+ "import torch\n",
35
+ "from scipy.io.wavfile import write\n",
36
+ "from env import AttrDict\n",
37
+ "from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav\n",
38
+ "from models import Generator\n",
39
+ "from stft import TorchSTFT\n",
40
+ "\n",
41
+ "from Utils.JDC.model import JDCNet"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 2,
47
+ "id": "3ee13ffd",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "h = None\n",
52
+ "device = None\n",
53
+ "\n",
54
+ "\n",
55
+ "def load_checkpoint(filepath, device):\n",
56
+ " assert os.path.isfile(filepath)\n",
57
+ " print(\"Loading '{}'\".format(filepath))\n",
58
+ " checkpoint_dict = torch.load(filepath, map_location=device)\n",
59
+ " print(\"Complete.\")\n",
60
+ " return checkpoint_dict\n",
61
+ "\n",
62
+ "\n",
63
+ "def get_mel(x):\n",
64
+ " return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)\n",
65
+ "\n",
66
+ "\n",
67
+ "def scan_checkpoint(cp_dir, prefix):\n",
68
+ " pattern = os.path.join(cp_dir, prefix + '*')\n",
69
+ " cp_list = glob.glob(pattern)\n",
70
+ " if len(cp_list) == 0:\n",
71
+ " return ''\n",
72
+ " return sorted(cp_list)[-1]"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 3,
78
+ "id": "003b1249",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "F0_model = JDCNet(num_class=1, seq_len=192)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": 4,
88
+ "id": "321eb3b5",
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "cp_path = \"/home/ubuntu/RINGFORMER/cp_ringformer_44.1khz\""
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": 5,
98
+ "id": "3dcc2764",
99
+ "metadata": {},
100
+ "outputs": [
101
+ {
102
+ "data": {
103
+ "text/plain": [
104
+ "{'F0_path': '/home/ubuntu/Darya_Speech/Utils/JDC/epoch_00079.pth',\n",
105
+ " 'resblock': '1',\n",
106
+ " 'num_gpus': 1,\n",
107
+ " 'batch_size': 18,\n",
108
+ " 'learning_rate': 0.0002,\n",
109
+ " 'adam_b1': 0.8,\n",
110
+ " 'adam_b2': 0.99,\n",
111
+ " 'lr_decay': 0.999,\n",
112
+ " 'seed': 1234,\n",
113
+ " 'upsample_rates': [16, 8],\n",
114
+ " 'upsample_kernel_sizes': [32, 16],\n",
115
+ " 'upsample_initial_channel': 512,\n",
116
+ " 'resblock_kernel_sizes': [3, 7, 11],\n",
117
+ " 'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n",
118
+ " 'gen_istft_n_fft': 32,\n",
119
+ " 'gen_istft_hop_size': 4,\n",
120
+ " 'segment_size': 65536,\n",
121
+ " 'num_mels': 128,\n",
122
+ " 'n_fft': 2048,\n",
123
+ " 'hop_size': 512,\n",
124
+ " 'win_size': 2048,\n",
125
+ " 'sampling_rate': 44100,\n",
126
+ " 'fmin': 0,\n",
127
+ " 'fmax': None,\n",
128
+ " 'fmax_for_loss': None,\n",
129
+ " 'num_workers': 8,\n",
130
+ " 'dist_config': {'dist_backend': 'nccl',\n",
131
+ " 'dist_url': 'tcp://localhost:54321',\n",
132
+ " 'world_size': 1}}"
133
+ ]
134
+ },
135
+ "execution_count": 5,
136
+ "metadata": {},
137
+ "output_type": "execute_result"
138
+ }
139
+ ],
140
+ "source": [
141
+ "with open(cp_path + \"/config.json\") as f:\n",
142
+ " data = f.read()\n",
143
+ "\n",
144
+ "json_config = json.loads(data)\n",
145
+ "h = AttrDict(json_config)\n",
146
+ "h"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": 6,
152
+ "id": "a4c78cb6",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "# device = torch.device('cuda:{:d}'.format(0))\n",
157
+ "\n",
158
+ "device = 'cuda:0'"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 7,
164
+ "id": "5a782adb",
165
+ "metadata": {},
166
+ "outputs": [
167
+ {
168
+ "name": "stderr",
169
+ "output_type": "stream",
170
+ "text": [
171
+ "/home/ubuntu/miniconda3/envs/respair/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
172
+ " WeightNorm.apply(module, name, dim)\n"
173
+ ]
174
+ }
175
+ ],
176
+ "source": [
177
+ "generator = Generator(h, F0_model).to(device)\n",
178
+ "stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 8,
184
+ "id": "6f0a7c64",
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stdout",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "Loading '/home/ubuntu/RINGFORMER/cp_ringformer_44.1khz/g_00017000'\n"
192
+ ]
193
+ },
194
+ {
195
+ "name": "stderr",
196
+ "output_type": "stream",
197
+ "text": [
198
+ "/tmp/ipykernel_3972638/3295719764.py:8: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
199
+ " checkpoint_dict = torch.load(filepath, map_location=device)\n"
200
+ ]
201
+ },
202
+ {
203
+ "name": "stdout",
204
+ "output_type": "stream",
205
+ "text": [
206
+ "Complete.\n",
207
+ "Removing weight norm...\n"
208
+ ]
209
+ }
210
+ ],
211
+ "source": [
212
+ "cp_g = scan_checkpoint(cp_path, 'g_')\n",
213
+ "state_dict_g = load_checkpoint(cp_g, device)\n",
214
+ "generator.load_state_dict(state_dict_g['generator'])\n",
215
+ "generator.remove_weight_norm()\n",
216
+ "_ = generator.eval()"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "id": "a115a967",
222
+ "metadata": {},
223
+ "source": [
224
+ "### Resynthesis"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "cbeee500",
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "import torchaudio\n",
235
+ "from librosa.filters import mel as librosa_mel_fn\n",
236
+ "from IPython.display import Audio\n",
237
+ "import librosa\n",
238
+ "\n",
239
+ "to_mel = torchaudio.transforms.MelSpectrogram(\n",
240
+ " n_mels=128, n_fft=2048, win_length=2048, hop_length=512, sample_rate=44100, power=2.5)\n",
241
+ "mean, std = -4, 4\n",
242
+ "\n",
243
+ "def preprocess(wave):\n",
244
+ " wave_tensor = torch.FloatTensor(wav)\n",
245
+ " mel_tensor = to_mel(wave_tensor)\n",
246
+ " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
247
+ " return mel_tensor.to('cuda:0')\n",
248
+ "\n",
249
+ "\n",
250
+ "wav = librosa.load(\"/your.wav\", sr=44100)[0]\n",
251
+ "\n",
252
+ "x = preprocess(wav)\n",
253
+ "print(x.shape)\n",
254
+ "\n",
255
+ "n = 1\n",
256
+ "xxx = torch.load(\"/home/ubuntu/RINGFORMER/gt.pt\").to('cuda:0')[n:n+1,:,:]\n",
257
+ "with torch.no_grad():\n",
258
+ "\n",
259
+ " spec, phase = generator(xxx)\n",
260
+ " y_g_hat = stft.inverse(spec, phase)\n",
261
+ " audio = y_g_hat.squeeze()\n",
262
+ " # audio = audio * MAX_WAV_VALUE\n",
263
+ " audio = audio.cpu().numpy()\n",
264
+ "\n",
265
+ "\n",
266
+ "print('Synthesized:')\n",
267
+ "display(Audio(audio, rate=44100))"
268
+ ]
269
+ }
270
+ ],
271
+ "metadata": {
272
+ "kernelspec": {
273
+ "display_name": "respair",
274
+ "language": "python",
275
+ "name": "python3"
276
+ },
277
+ "language_info": {
278
+ "codemirror_mode": {
279
+ "name": "ipython",
280
+ "version": 3
281
+ },
282
+ "file_extension": ".py",
283
+ "mimetype": "text/x-python",
284
+ "name": "python",
285
+ "nbconvert_exporter": "python",
286
+ "pygments_lexer": "ipython3",
287
+ "version": "3.11.0"
288
+ }
289
+ },
290
+ "nbformat": 4,
291
+ "nbformat_minor": 5
292
+ }
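A hedged restatement of the notebook's mel preprocessing as a standalone snippet. Note that the notebook's `preprocess` reads the global `wav` rather than its `wave` argument; the version below uses the argument, which is presumably the intent. The normalisation constants (`mean=-4`, `std=4`) and the 44.1 kHz / 128-mel settings match the notebook and `meldataset.py`; the wav path is a placeholder.

```python
# Mel preprocessing for resynthesis, restated from inference.ipynb with the
# argument actually used inside the function.
import torch
import torchaudio
import librosa

to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=128, n_fft=2048, win_length=2048, hop_length=512,
    sample_rate=44100, power=2.5)
mean, std = -4, 4

def preprocess(wave, device="cpu"):
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor.to(device)

wav = librosa.load("your.wav", sr=44100)[0]   # placeholder path
x = preprocess(wav)
print(x.shape)                                # (1, 128, n_frames)
```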
RingFormer/meldataset.py ADDED
@@ -0,0 +1,203 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ import numpy as np
7
+ from librosa.util import normalize
8
+ from scipy.io.wavfile import read
9
+ import torchaudio
10
+ import librosa
11
+ from librosa.filters import mel as librosa_mel_fn
12
+
13
+ MAX_WAV_VALUE = 32768.0
14
+ import soundfile as sf
15
+
16
+
17
+ def normalize_audio(wav):
18
+ return wav / torch.max(torch.abs(torch.from_numpy(wav))) # Correct peak normalization
19
+
20
+ def load_wav(full_path):
21
+ data, sampling_rate = librosa.load(full_path, sr=None)
22
+ return data, sampling_rate
23
+
24
+
25
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
26
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
27
+
28
+
29
+ def dynamic_range_decompression(x, C=1):
30
+ return np.exp(x) / C
31
+
32
+
33
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
34
+ return torch.log(torch.clamp(x, min=clip_val) * C)
35
+
36
+
37
+ def dynamic_range_decompression_torch(x, C=1):
38
+ return torch.exp(x) / C
39
+
40
+
41
+ def spectral_normalize_torch(magnitudes):
42
+ output = dynamic_range_compression_torch(magnitudes)
43
+ return output
44
+
45
+
46
+ def spectral_de_normalize_torch(magnitudes):
47
+ output = dynamic_range_decompression_torch(magnitudes)
48
+ return output
49
+
50
+
51
+ mel_basis = {}
52
+ hann_window = {}
53
+
54
+
55
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
56
+
57
+ y = torch.clamp(y, min=-1, max=1)
58
+
59
+ if torch.min(y) < -1.:
60
+ print('min value is ', torch.min(y))
61
+ if torch.max(y) > 1.:
62
+ print('max value is ', torch.max(y))
63
+
64
+ global mel_basis, hann_window
65
+ if fmax not in mel_basis:
66
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
67
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
68
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
69
+
70
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
71
+ y = y.squeeze(1)
72
+
73
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
74
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
75
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
76
+ spec = torch.view_as_real(spec)
77
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
78
+
79
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
80
+ spec = spectral_normalize_torch(spec)
81
+
82
+ return spec
83
+
84
+
85
+
86
+ to_mel = torchaudio.transforms.MelSpectrogram(
87
+ sample_rate=44_100, n_mels=128, n_fft=2048, win_length=2048, hop_length=512)
88
+
89
+
90
+
91
+ # to_mel = torchaudio.transforms.MelSpectrogram(
92
+ # sample_rate=24000, n_mels=80, n_fft=2048, win_length=1200, hop_length=300, center='center')
93
+
94
+ mean, std = -4, 4
95
+
96
+ def preproces(wave,to_mel=to_mel, device='cpu'):
97
+
98
+ to_mel = to_mel.to(device)
99
+ # wave_tensor = torch.from_numpy(wave).float()
100
+ mel_tensor = to_mel(wave)
101
+ mel_tensor = (torch.log(1e-5 + mel_tensor) - mean) / std
102
+ return mel_tensor
103
+
104
+
105
+ def get_dataset_filelist(a):
106
+ with open(a.input_training_file, 'r', encoding='utf-8') as fi:
107
+ training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('' if '' not in x else ''))
108
+ for x in fi.read().split('\n') if len(x) > 0]
109
+
110
+ with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
111
+ validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('' if '' not in x else ''))
112
+ for x in fi.read().split('\n') if len(x) > 0]
113
+ return training_files, validation_files
114
+
115
+
116
+ class MelDataset(torch.utils.data.Dataset):
117
+ def __init__(self, training_files, segment_size, n_fft, num_mels,
118
+ hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
119
+ device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
120
+ self.audio_files = training_files
121
+ random.seed(1234)
122
+ if shuffle:
123
+ random.shuffle(self.audio_files)
124
+ self.segment_size = segment_size
125
+ self.sampling_rate = sampling_rate
126
+ self.split = split
127
+ self.n_fft = n_fft
128
+ self.num_mels = num_mels
129
+ self.hop_size = hop_size
130
+ self.win_size = win_size
131
+ self.fmin = fmin
132
+ self.fmax = fmax
133
+ self.fmax_loss = fmax_loss
134
+ self.cached_wav = None
135
+ self.n_cache_reuse = n_cache_reuse
136
+ self._cache_ref_count = 0
137
+ self.device = device
138
+ self.fine_tuning = fine_tuning
139
+ self.base_mels_path = base_mels_path
140
+
141
+ def __getitem__(self, index):
142
+ filename = self.audio_files[index]
143
+ if self._cache_ref_count == 0:
144
+ audio, sampling_rate = load_wav(filename)
145
+
146
+
147
+ self.cached_wav = audio
148
+ if sampling_rate != self.sampling_rate:
149
+ raise ValueError("{} SR doesn't match target {} SR".format(
150
+ sampling_rate, self.sampling_rate))
151
+ self._cache_ref_count = self.n_cache_reuse
152
+ else:
153
+ audio = self.cached_wav
154
+ self._cache_ref_count -= 1
155
+
156
+ audio = torch.FloatTensor(audio)
157
+ audio = audio.unsqueeze(0)
158
+
159
+ if not self.fine_tuning:
160
+ if self.split:
161
+ if audio.size(1) >= self.segment_size:
162
+ max_audio_start = audio.size(1) - self.segment_size
163
+ audio_start = random.randint(0, max_audio_start)
164
+ audio = audio[:, audio_start:audio_start+self.segment_size]
165
+ else:
166
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
167
+
168
+ # mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
169
+ # self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
170
+ # center=False)
171
+
172
+ mel = preproces(audio)
173
+ else:
174
+ mel = np.load(
175
+ os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
176
+ mel = torch.from_numpy(mel)
177
+
178
+ if len(mel.shape) < 3:
179
+ mel = mel.unsqueeze(0)
180
+
181
+ if self.split:
182
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
183
+
184
+ if audio.size(1) >= self.segment_size:
185
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
186
+ mel = mel[:, :, mel_start:mel_start + frames_per_seg]
187
+ audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
188
+ else:
189
+ mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
190
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
191
+
192
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
193
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
194
+ center=False)
195
+
196
+ # mel_loss = mel_spectrogram(audio)
197
+ if mel.shape[-1] != mel_loss.shape[-1]:
198
+ mel = mel[..., :mel_loss.shape[-1]]
199
+
200
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
201
+
202
+ def __len__(self):
203
+ return len(self.audio_files)
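With the 44.1 kHz settings from `config_v1.json`, a 65 536-sample training segment maps to exactly 128 mel frames. A hedged check of `mel_spectrogram`'s output shape, assuming the file above is importable as `meldataset`:

```python
# Hedged shape check of mel_spectrogram with the config_v1.json settings:
# 65536 samples / hop 512 = 128 frames.
import torch
from meldataset import mel_spectrogram

y = torch.randn(1, 65536)  # stand-in for one training segment (the function clamps to [-1, 1])
mel = mel_spectrogram(y, n_fft=2048, num_mels=128, sampling_rate=44100,
                      hop_size=512, win_size=2048, fmin=0, fmax=None)
print(mel.shape)           # torch.Size([1, 128, 128]) -- (batch, n_mels, frames)
```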
RingFormer/models.py ADDED
@@ -0,0 +1,943 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from utils import init_weights, get_padding
7
+ import numpy as np
8
+ from stft import TorchSTFT
9
+ import torchaudio
10
+ from nnAudio import features
11
+ from einops import rearrange
12
+ from norm2d import NormConv2d
13
+ from utils import get_padding
14
+ from munch import Munch
15
+ from conformer import Conformer
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ def get_2d_padding(kernel_size, dilation=(1, 1)):
21
+ return (
22
+ ((kernel_size[0] - 1) * dilation[0]) // 2,
23
+ ((kernel_size[1] - 1) * dilation[1]) // 2,
24
+ )
25
+
26
+
27
+
28
+ class ResBlock1(torch.nn.Module):
29
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
30
+ super(ResBlock1, self).__init__()
31
+ self.h = h
32
+ self.convs1 = nn.ModuleList([
33
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
34
+ padding=get_padding(kernel_size, dilation[0]))),
35
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
36
+ padding=get_padding(kernel_size, dilation[1]))),
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
38
+ padding=get_padding(kernel_size, dilation[2])))
39
+ ])
40
+ self.convs1.apply(init_weights)
41
+
42
+ self.convs2 = nn.ModuleList([
43
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
44
+ padding=get_padding(kernel_size, 1))),
45
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46
+ padding=get_padding(kernel_size, 1))),
47
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
+ padding=get_padding(kernel_size, 1)))
49
+ ])
50
+ self.convs2.apply(init_weights)
51
+
52
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
53
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
54
+
55
+
56
+ def forward(self, x):
57
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.alpha1, self.alpha2):
58
+ xt = x + (1 / a1) * (torch.sin(a1 * x) ** 2) # Snake1D
59
+ xt = c1(xt)
60
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
61
+ xt = c2(xt)
62
+ x = xt + x
63
+ return x
64
+
65
+ def remove_weight_norm(self):
66
+ for l in self.convs1:
67
+ remove_weight_norm(l)
68
+ for l in self.convs2:
69
+ remove_weight_norm(l)
70
+
71
+ class ResBlock1_old(torch.nn.Module):
72
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
73
+ super(ResBlock1, self).__init__()
74
+ self.h = h
75
+ self.convs1 = nn.ModuleList([
76
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
77
+ padding=get_padding(kernel_size, dilation[0]))),
78
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
79
+ padding=get_padding(kernel_size, dilation[1]))),
80
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
81
+ padding=get_padding(kernel_size, dilation[2])))
82
+ ])
83
+ self.convs1.apply(init_weights)
84
+
85
+ self.convs2 = nn.ModuleList([
86
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
87
+ padding=get_padding(kernel_size, 1))),
88
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
89
+ padding=get_padding(kernel_size, 1))),
90
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
91
+ padding=get_padding(kernel_size, 1)))
92
+ ])
93
+ self.convs2.apply(init_weights)
94
+
95
+ def forward(self, x):
96
+ for c1, c2 in zip(self.convs1, self.convs2):
97
+ xt = F.leaky_relu(x, LRELU_SLOPE)
98
+ xt = c1(xt)
99
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
100
+ xt = c2(xt)
101
+ x = xt + x
102
+ return x
103
+
104
+ def remove_weight_norm(self):
105
+ for l in self.convs1:
106
+ remove_weight_norm(l)
107
+ for l in self.convs2:
108
+ remove_weight_norm(l)
109
+
110
+
111
+ class ResBlock2(torch.nn.Module):
112
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
113
+ super(ResBlock2, self).__init__()
114
+ self.h = h
115
+ self.convs = nn.ModuleList([
116
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
117
+ padding=get_padding(kernel_size, dilation[0]))),
118
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
119
+ padding=get_padding(kernel_size, dilation[1])))
120
+ ])
121
+ self.convs.apply(init_weights)
122
+
123
+ def forward(self, x):
124
+ for c in self.convs:
125
+ xt = F.leaky_relu(x, LRELU_SLOPE)
126
+ xt = c(xt)
127
+ x = xt + x
128
+ return x
129
+
130
+ def remove_weight_norm(self):
131
+ for l in self.convs:
132
+ remove_weight_norm(l)
133
+
134
+
135
+ class SineGen(torch.nn.Module):
136
+ """ Definition of sine generator
137
+ SineGen(samp_rate, harmonic_num = 0,
138
+ sine_amp = 0.1, noise_std = 0.003,
139
+ voiced_threshold = 0,
140
+ flag_for_pulse=False)
141
+ samp_rate: sampling rate in Hz
142
+ harmonic_num: number of harmonic overtones (default 0)
143
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
144
+ noise_std: std of Gaussian noise (default 0.003)
145
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
146
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
147
+ Note: when flag_for_pulse is True, the first time step of a voiced
148
+ segment is always sin(np.pi) or cos(0)
149
+ """
150
+
151
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
152
+ sine_amp=0.1, noise_std=0.003,
153
+ voiced_threshold=0,
154
+ flag_for_pulse=False):
155
+ super(SineGen, self).__init__()
156
+ self.sine_amp = sine_amp
157
+ self.noise_std = noise_std
158
+ self.harmonic_num = harmonic_num
159
+ self.dim = self.harmonic_num + 1
160
+ self.sampling_rate = samp_rate
161
+ self.voiced_threshold = voiced_threshold
162
+ self.flag_for_pulse = flag_for_pulse
163
+ self.upsample_scale = upsample_scale
164
+
165
+ def _f02uv(self, f0):
166
+ # generate uv signal
167
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
168
+ return uv
169
+
170
+ def _f02sine(self, f0_values):
171
+ """ f0_values: (batchsize, length, dim)
172
+ where dim indicates fundamental tone and overtones
173
+ """
174
+ # convert to F0 in rad. The interger part n can be ignored
175
+ # because 2 * np.pi * n doesn't affect phase
176
+ rad_values = (f0_values / self.sampling_rate) % 1
177
+
178
+ # initial phase noise (no noise for fundamental component)
179
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
180
+ device=f0_values.device)
181
+ rand_ini[:, 0] = 0
182
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
183
+
184
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
185
+ if not self.flag_for_pulse:
186
+ # # for normal case
187
+
188
+ # # To prevent torch.cumsum numerical overflow,
189
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
190
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
191
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
192
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
193
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
194
+ # cumsum_shift = torch.zeros_like(rad_values)
195
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
196
+
197
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
198
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
199
+ scale_factor=1/self.upsample_scale,
200
+ mode="linear").transpose(1, 2)
201
+
202
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
203
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
204
+ # cumsum_shift = torch.zeros_like(rad_values)
205
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
206
+
207
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
208
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
209
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
210
+ sines = torch.sin(phase)
211
+
212
+ else:
213
+ # If necessary, make sure that the first time step of every
214
+ # voiced segments is sin(pi) or cos(0)
215
+ # This is used for pulse-train generation
216
+
217
+ # identify the last time step in unvoiced segments
218
+ uv = self._f02uv(f0_values)
219
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
220
+ uv_1[:, -1, :] = 1
221
+ u_loc = (uv < 1) * (uv_1 > 0)
222
+
223
+ # get the instantanouse phase
224
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
225
+ # different batch needs to be processed differently
226
+ for idx in range(f0_values.shape[0]):
227
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
228
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
229
+ # stores the accumulation of i.phase within
230
+ # each voiced segments
231
+ tmp_cumsum[idx, :, :] = 0
232
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
233
+
234
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
235
+ # within the previous voiced segment.
236
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
237
+
238
+ # get the sines
239
+ sines = torch.cos(i_phase * 2 * np.pi)
240
+ return sines
241
+
242
+ def forward(self, f0):
243
+ """ sine_tensor, uv = forward(f0)
244
+ input F0: tensor(batchsize=1, length, dim=1)
245
+ f0 for unvoiced steps should be 0
246
+ output sine_tensor: tensor(batchsize=1, length, dim)
247
+ output uv: tensor(batchsize=1, length, 1)
248
+ """
249
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
250
+ device=f0.device)
251
+ # fundamental component
252
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
253
+
254
+ # generate sine waveforms
255
+ sine_waves = self._f02sine(fn) * self.sine_amp
256
+
257
+ # generate uv signal
258
+ # uv = torch.ones(f0.shape)
259
+ # uv = uv * (f0 > self.voiced_threshold)
260
+ uv = self._f02uv(f0)
261
+
262
+ # noise: for unvoiced should be similar to sine_amp
263
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
264
+ # . for voiced regions is self.noise_std
265
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
266
+ noise = noise_amp * torch.randn_like(sine_waves)
267
+
268
+ # first: set the unvoiced part to 0 by uv
269
+ # then: additive noise
270
+ sine_waves = sine_waves * uv + noise
271
+ return sine_waves, uv, noise
272
+
273
+
274
+ class SourceModuleHnNSF(torch.nn.Module):
275
+ """ SourceModule for hn-nsf
276
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
277
+ add_noise_std=0.003, voiced_threshod=0)
278
+ sampling_rate: sampling_rate in Hz
279
+ harmonic_num: number of harmonic above F0 (default: 0)
280
+ sine_amp: amplitude of sine source signal (default: 0.1)
281
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
282
+ note that amplitude of noise in unvoiced is decided
283
+ by sine_amp
284
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
285
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
286
+ F0_sampled (batchsize, length, 1)
287
+ Sine_source (batchsize, length, 1)
288
+ noise_source (batchsize, length 1)
289
+ uv (batchsize, length, 1)
290
+ """
291
+
292
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
293
+ add_noise_std=0.003, voiced_threshod=0):
294
+ super(SourceModuleHnNSF, self).__init__()
295
+
296
+ self.sine_amp = sine_amp
297
+ self.noise_std = add_noise_std
298
+
299
+ # to produce sine waveforms
300
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
301
+ sine_amp, add_noise_std, voiced_threshod)
302
+
303
+ # to merge source harmonics into a single excitation
304
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
305
+ self.l_tanh = torch.nn.Tanh()
306
+
307
+ def forward(self, x):
308
+ """
309
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
310
+ F0_sampled (batchsize, length, 1)
311
+ Sine_source (batchsize, length, 1)
312
+ noise_source (batchsize, length 1)
313
+ """
314
+ # source for harmonic branch
315
+ with torch.no_grad():
316
+ sine_wavs, uv, _ = self.l_sin_gen(x)
317
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
318
+
319
+ # source for noise branch, in the same shape as uv
320
+ noise = torch.randn_like(uv) * self.sine_amp / 3
321
+ return sine_merge, noise, uv
322
+ def padDiff(x):
323
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
324
+
325
+
326
+
327
+ class Generator(torch.nn.Module):
328
+ def __init__(self, h, F0_model):
329
+ super(Generator, self).__init__()
330
+ self.h = h
331
+ self.num_kernels = len(h.resblock_kernel_sizes)
332
+ self.num_upsamples = len(h.upsample_rates)
333
+ self.conv_pre = weight_norm(Conv1d(128, h.upsample_initial_channel, 7, 1, padding=3))
334
+
335
+
336
+
337
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
338
+
339
+ self.m_source = SourceModuleHnNSF(
340
+ sampling_rate=h.sampling_rate,
341
+ upsample_scale=np.prod(h.upsample_rates) * h.gen_istft_hop_size,
342
+ harmonic_num=8, voiced_threshod=10)
343
+
344
+ self.f0_upsamp = torch.nn.Upsample(
345
+ scale_factor=np.prod(h.upsample_rates) * h.gen_istft_hop_size)
346
+ self.noise_convs = nn.ModuleList()
347
+ self.noise_res = nn.ModuleList()
348
+
349
+ self.F0_model = F0_model
350
+
351
+ self.ups = nn.ModuleList()
352
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
353
+ self.ups.append(weight_norm(
354
+ ConvTranspose1d(h.upsample_initial_channel//(2**i),
355
+ h.upsample_initial_channel//(2**(i+1)),
356
+ k,
357
+ u,
358
+ padding=(k-u)//2)))
359
+
360
+ c_cur = h.upsample_initial_channel // (2 ** (i + 1))
361
+
362
+ if i + 1 < len(h.upsample_rates): #
363
+ stride_f0 = np.prod(h.upsample_rates[i + 1:])
364
+ self.noise_convs.append(Conv1d(
365
+ h.gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
366
+ self.noise_res.append(resblock(h, c_cur, 7, [1,3,5]))
367
+ else:
368
+ self.noise_convs.append(Conv1d(h.gen_istft_n_fft + 2, c_cur, kernel_size=1))
369
+ self.noise_res.append(resblock(h, c_cur, 11, [1,3,5]))
370
+
371
+ self.alphas = nn.ParameterList()
372
+ self.alphas.append(nn.Parameter(torch.ones(1, h.upsample_initial_channel, 1)))
373
+ self.resblocks = nn.ModuleList()
374
+ for i in range(len(self.ups)):
375
+ ch = h.upsample_initial_channel//(2**(i+1))
376
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
377
+ for j, (k, d) in enumerate(
378
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
379
+ self.resblocks.append(resblock(h, ch, k, d))
380
+
381
+
382
+ self.conformers = nn.ModuleList()
383
+ self.post_n_fft = h.gen_istft_n_fft
384
+ self.conv_post = weight_norm(Conv1d(128, self.post_n_fft + 2, 7, 1, padding=3))
385
+
386
+ for i in range(len(self.ups)):
387
+ ch = h.upsample_initial_channel // (2**i)
388
+ self.conformers.append(
389
+ Conformer(
390
+ dim=ch,
391
+ depth=2,
392
+ dim_head=64,
393
+ heads=8,
394
+ ff_mult=4,
395
+ conv_expansion_factor=2,
396
+ conv_kernel_size=31,
397
+ attn_dropout=0.1,
398
+ ff_dropout=0.1,
399
+ conv_dropout=0.1,
400
+ # device=self.device
401
+ )
402
+ )
403
+
404
+ self.ups.apply(init_weights)
405
+ self.conv_post.apply(init_weights)
406
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
407
+ self.stft = TorchSTFT(filter_length=h.gen_istft_n_fft,
408
+ hop_length=h.gen_istft_hop_size,
409
+ win_length=h.gen_istft_n_fft)
410
+
411
+
412
+
413
+ def forward(self, x):
414
+
415
+
416
+
417
+ f0, _, _ = self.F0_model(x.unsqueeze(1))
418
+ if len(f0.shape) == 1:
419
+ f0 = f0.unsqueeze(0)
420
+
421
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
422
+
423
+ har_source, _, _ = self.m_source(f0)
424
+ har_source = har_source.transpose(1, 2).squeeze(1)
425
+ har_spec, har_phase = self.stft.transform(har_source)
426
+ har = torch.cat([har_spec, har_phase], dim=1)
427
+
428
+
429
+ x = self.conv_pre(x)
430
+
431
+ for i in range(self.num_upsamples):
432
+
433
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
434
+ x = rearrange(x, "b f t -> b t f")
435
+
436
+ x = self.conformers[i](x)
437
+
438
+ x = rearrange(x, "b t f -> b f t")
439
+
440
+ # x = F.leaky_relu(x, LRELU_SLOPE)
441
+ x_source = self.noise_convs[i](har)
442
+ x_source = self.noise_res[i](x_source)
443
+
444
+ x = self.ups[i](x)
445
+ if i == self.num_upsamples - 1:
446
+ x = self.reflection_pad(x)
447
+
448
+ x = x + x_source
449
+
450
+
451
+ xs = None
452
+ for j in range(self.num_kernels):
453
+ if xs is None:
454
+ xs = self.resblocks[i*self.num_kernels+j](x)
455
+ else:
456
+ xs += self.resblocks[i*self.num_kernels+j](x)
457
+ x = xs / self.num_kernels
458
+ # x = F.leaky_relu(x)
459
+
460
+
461
+ x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)
462
+
463
+ x = self.conv_post(x)
464
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]).to(x.device)
465
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]).to(x.device)
466
+
467
+ return spec, phase
468
+
469
+ def remove_weight_norm(self):
470
+ print("Removing weight norm...")
471
+ for l in self.ups:
472
+ remove_weight_norm(l)
473
+ for l in self.resblocks:
474
+ l.remove_weight_norm()
475
+ remove_weight_norm(self.conv_pre)
476
+ remove_weight_norm(self.conv_post)
477
+
478
+
479
+
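The Generator does not emit a waveform directly: `forward` returns the magnitude and phase of an iSTFT head (note the learnable x + sin²(αx)/α activations between the Conformer/upsample stages), and the waveform is recovered with `TorchSTFT.inverse`, exactly as train.py does. A minimal sketch, assuming `config_v1.json` and the JDC F0 checkpoint referenced by `h.F0_path`; the 128-band mel matches `conv_pre`'s 128 input channels:

```python
import json
import torch
from env import AttrDict
from models import Generator
from stft import TorchSTFT
from Utils.JDC.model import JDCNet

with open('config_v1.json') as f:
    h = AttrDict(json.load(f))

# pitch extractor used inside the Generator (loaded the same way as in train.py)
F0_model = JDCNet(num_class=1, seq_len=192)
F0_model.load_state_dict(torch.load(h.F0_path, map_location='cpu')['model'])

generator = Generator(h, F0_model).eval()
stft = TorchSTFT(filter_length=h.gen_istft_n_fft,
                 hop_length=h.gen_istft_hop_size,
                 win_length=h.gen_istft_n_fft)

mel = torch.randn(1, 128, 100)            # (B, mel bins, frames); random stand-in
with torch.no_grad():
    spec, phase = generator(mel)          # magnitude / phase for the iSTFT head
    audio = stft.inverse(spec, phase)     # (B, 1, samples), as in the training loop
```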
480
+ def stft(x, fft_size, hop_size, win_length, window):
481
+ """Perform STFT and convert to magnitude spectrogram.
482
+ Args:
483
+ x (Tensor): Input signal tensor (B, T).
484
+ fft_size (int): FFT size.
485
+ hop_size (int): Hop size.
486
+ win_length (int): Window length.
487
+ window (str): Window function type.
488
+ Returns:
489
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
490
+ """
491
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
492
+ return_complex=True)
493
+ real = x_stft.real
494
+ imag = x_stft[..., 1]
495
+
496
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
497
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
498
+
499
+ class SpecDiscriminator(nn.Module):
500
+ """docstring for Discriminator."""
501
+
502
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
503
+ super(SpecDiscriminator, self).__init__()
504
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
505
+ self.fft_size = fft_size
506
+ self.shift_size = shift_size
507
+ self.win_length = win_length
508
+ self.window = getattr(torch, window)(win_length)
509
+ self.discriminators = nn.ModuleList([
510
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
511
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
512
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
513
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
514
+ norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
515
+ ])
516
+
517
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
518
+
519
+ def forward(self, y):
520
+
521
+ fmap = []
522
+ y = y.squeeze(1)
523
+ y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
524
+ y = y.unsqueeze(1)
525
+ for i, d in enumerate(self.discriminators):
526
+ y = d(y)
527
+ y = F.leaky_relu(y, LRELU_SLOPE)
528
+ fmap.append(y)
529
+
530
+ y = self.out(y)
531
+ fmap.append(y)
532
+
533
+ return torch.flatten(y, 1, -1), fmap
534
+
535
+ # class MultiResSpecDiscriminator(torch.nn.Module):
536
+
537
+ # def __init__(self,
538
+ # fft_sizes=[1024, 2048, 512],
539
+ # hop_sizes=[120, 240, 50],
540
+ # win_lengths=[600, 1200, 240],
541
+ # window="hann_window"):
542
+
543
+ # super(MultiResSpecDiscriminator, self).__init__()
544
+ # self.discriminators = nn.ModuleList([
545
+ # SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
546
+ # SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
547
+ # SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
548
+ # ])
549
+
550
+ # def forward(self, y, y_hat):
551
+ # y_d_rs = []
552
+ # y_d_gs = []
553
+ # fmap_rs = []
554
+ # fmap_gs = []
555
+ # for i, d in enumerate(self.discriminators):
556
+ # y_d_r, fmap_r = d(y)
557
+ # y_d_g, fmap_g = d(y_hat)
558
+ # y_d_rs.append(y_d_r)
559
+ # fmap_rs.append(fmap_r)
560
+ # y_d_gs.append(y_d_g)
561
+ # fmap_gs.append(fmap_g)
562
+
563
+ # return y_d_rs, y_d_gs, fmap_rs, fmap_gs
564
+
565
+
566
+ class DiscriminatorP(torch.nn.Module):
567
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
568
+ super(DiscriminatorP, self).__init__()
569
+ self.period = period
570
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
571
+ self.convs = nn.ModuleList([
572
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
573
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
574
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
575
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
576
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
577
+ ])
578
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
579
+
580
+ def forward(self, x):
581
+ fmap = []
582
+
583
+ # 1d to 2d
584
+ b, c, t = x.shape
585
+ if t % self.period != 0: # pad first
586
+ n_pad = self.period - (t % self.period)
587
+ x = F.pad(x, (0, n_pad), "reflect")
588
+ t = t + n_pad
589
+ x = x.view(b, c, t // self.period, self.period)
590
+
591
+ for l in self.convs:
592
+ x = l(x)
593
+ x = F.leaky_relu(x, LRELU_SLOPE)
594
+ fmap.append(x)
595
+ x = self.conv_post(x)
596
+ fmap.append(x)
597
+ x = torch.flatten(x, 1, -1)
598
+
599
+ return x, fmap
600
+
601
+
602
+ class MultiPeriodDiscriminator(torch.nn.Module):
603
+ def __init__(self):
604
+ super(MultiPeriodDiscriminator, self).__init__()
605
+ self.discriminators = nn.ModuleList([
606
+ DiscriminatorP(2),
607
+ DiscriminatorP(3),
608
+ DiscriminatorP(5),
609
+ DiscriminatorP(7),
610
+ DiscriminatorP(11),
611
+ ])
612
+
613
+ def forward(self, y, y_hat):
614
+ y_d_rs = []
615
+ y_d_gs = []
616
+ fmap_rs = []
617
+ fmap_gs = []
618
+ for i, d in enumerate(self.discriminators):
619
+ y_d_r, fmap_r = d(y)
620
+ y_d_g, fmap_g = d(y_hat)
621
+ y_d_rs.append(y_d_r)
622
+ fmap_rs.append(fmap_r)
623
+ y_d_gs.append(y_d_g)
624
+ fmap_gs.append(fmap_g)
625
+
626
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
627
+
628
+
629
+ class DiscriminatorS(torch.nn.Module):
630
+ def __init__(self, use_spectral_norm=False):
631
+ super(DiscriminatorS, self).__init__()
632
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
633
+ self.convs = nn.ModuleList([
634
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
635
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
636
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
637
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
638
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
639
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
640
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
641
+ ])
642
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
643
+
644
+ def forward(self, x):
645
+ fmap = []
646
+ for l in self.convs:
647
+ x = l(x)
648
+ x = F.leaky_relu(x, LRELU_SLOPE)
649
+ fmap.append(x)
650
+ x = self.conv_post(x)
651
+ fmap.append(x)
652
+ x = torch.flatten(x, 1, -1)
653
+
654
+ return x, fmap
655
+
656
+
657
+ class MultiScaleDiscriminator(torch.nn.Module):
658
+ def __init__(self):
659
+ super(MultiScaleDiscriminator, self).__init__()
660
+ self.discriminators = nn.ModuleList([
661
+ DiscriminatorS(use_spectral_norm=True),
662
+ DiscriminatorS(),
663
+ DiscriminatorS(),
664
+ ])
665
+ self.meanpools = nn.ModuleList([
666
+ AvgPool1d(4, 2, padding=2),
667
+ AvgPool1d(4, 2, padding=2)
668
+ ])
669
+
670
+ def forward(self, y, y_hat):
671
+ y_d_rs = []
672
+ y_d_gs = []
673
+ fmap_rs = []
674
+ fmap_gs = []
675
+ for i, d in enumerate(self.discriminators):
676
+ if i != 0:
677
+ y = self.meanpools[i-1](y)
678
+ y_hat = self.meanpools[i-1](y_hat)
679
+ y_d_r, fmap_r = d(y)
680
+ y_d_g, fmap_g = d(y_hat)
681
+ y_d_rs.append(y_d_r)
682
+ fmap_rs.append(fmap_r)
683
+ y_d_gs.append(y_d_g)
684
+ fmap_gs.append(fmap_g)
685
+
686
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
687
+
688
+
689
+
690
+
691
+
692
+ ########################### from ringformer
693
+
694
+ multiscale_subband_cfg = {
695
+ "hop_lengths": [1024, 512, 512], # Doubled to maintain similar time resolution
696
+ "sampling_rate": 44100, # New sampling rate
697
+ "filters": 32, # Kept same as it controls initial feature dimension
698
+ "max_filters": 1024, # Kept same as it's a maximum limit
699
+ "filters_scale": 1, # Kept same as it's a scaling factor
700
+ "dilations": [1, 2, 4], # Kept same as they control receptive field growth
701
+ "in_channels": 1, # Kept same (mono audio)
702
+ "out_channels": 1, # Kept same (mono audio)
703
+ "n_octaves": [10, 10, 10], # Increased by 1 to handle higher frequency range
704
+ "bins_per_octaves": [24, 36, 48], # Kept same as they control frequency resolution
705
+ }
706
+
707
+
708
+
709
+ # multiscale_subband_cfg = {
710
+ # "hop_lengths": [512, 256, 256],
711
+ # "sampling_rate": 24000,
712
+ # "filters": 32,
713
+ # "max_filters": 1024,
714
+ # "filters_scale": 1,
715
+ # "dilations": [1, 2, 4],
716
+ # "in_channels": 1,
717
+ # "out_channels": 1,
718
+ # "n_octaves": [9, 9, 9],
719
+ # "bins_per_octaves": [24, 36, 48],
720
+ # }
721
+
722
+ class DiscriminatorCQT(nn.Module):
723
+ def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
724
+ super(DiscriminatorCQT, self).__init__()
725
+ self.cfg = cfg
726
+
727
+ self.filters = cfg.filters
728
+ self.max_filters = cfg.max_filters
729
+ self.filters_scale = cfg.filters_scale
730
+ self.kernel_size = (3, 9)
731
+ self.dilations = cfg.dilations
732
+ self.stride = (1, 2)
733
+
734
+ self.in_channels = cfg.in_channels
735
+ self.out_channels = cfg.out_channels
736
+ self.fs = cfg.sampling_rate
737
+ self.hop_length = hop_length
738
+ self.n_octaves = n_octaves
739
+ self.bins_per_octave = bins_per_octave
740
+
741
+ self.cqt_transform = features.cqt.CQT2010v2(
742
+ sr=self.fs * 2,
743
+ hop_length=self.hop_length,
744
+ n_bins=self.bins_per_octave * self.n_octaves,
745
+ bins_per_octave=self.bins_per_octave,
746
+ output_format="Complex",
747
+ pad_mode="constant",
748
+ )
749
+
750
+ self.conv_pres = nn.ModuleList()
751
+ for i in range(self.n_octaves):
752
+ self.conv_pres.append(
753
+ NormConv2d(
754
+ self.in_channels * 2,
755
+ self.in_channels * 2,
756
+ kernel_size=self.kernel_size,
757
+ padding=get_2d_padding(self.kernel_size),
758
+ )
759
+ )
760
+
761
+ self.convs = nn.ModuleList()
762
+
763
+ self.convs.append(
764
+ NormConv2d(
765
+ self.in_channels * 2,
766
+ self.filters,
767
+ kernel_size=self.kernel_size,
768
+ padding=get_2d_padding(self.kernel_size),
769
+ )
770
+ )
771
+
772
+ in_chs = min(self.filters_scale * self.filters, self.max_filters)
773
+ for i, dilation in enumerate(self.dilations):
774
+ out_chs = min(
775
+ (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
776
+ )
777
+ self.convs.append(
778
+ NormConv2d(
779
+ in_chs,
780
+ out_chs,
781
+ kernel_size=self.kernel_size,
782
+ stride=self.stride,
783
+ dilation=(dilation, 1),
784
+ padding=get_2d_padding(self.kernel_size, (dilation, 1)),
785
+ norm="weight_norm",
786
+ )
787
+ )
788
+ in_chs = out_chs
789
+ out_chs = min(
790
+ (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
791
+ self.max_filters,
792
+ )
793
+ self.convs.append(
794
+ NormConv2d(
795
+ in_chs,
796
+ out_chs,
797
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
798
+ padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
799
+ norm="weight_norm",
800
+ )
801
+ )
802
+
803
+ self.conv_post = NormConv2d(
804
+ out_chs,
805
+ self.out_channels,
806
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
807
+ padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
808
+ norm="weight_norm",
809
+ )
810
+
811
+ self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
812
+ self.resample = torchaudio.transforms.Resample(
813
+ orig_freq=self.fs, new_freq=self.fs * 2
814
+ )
815
+
816
+ def forward(self, x):
817
+ fmap = []
818
+
819
+ x = self.resample(x)
820
+
821
+ z = self.cqt_transform(x)
822
+
823
+ z_amplitude = z[:, :, :, 0].unsqueeze(1)
824
+ z_phase = z[:, :, :, 1].unsqueeze(1)
825
+
826
+ z = torch.cat([z_amplitude, z_phase], dim=1)
827
+ z = rearrange(z, "b c w t -> b c t w")
828
+
829
+ latent_z = []
830
+ for i in range(self.n_octaves):
831
+ latent_z.append(
832
+ self.conv_pres[i](
833
+ z[
834
+ :,
835
+ :,
836
+ :,
837
+ i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
838
+ ]
839
+ )
840
+ )
841
+ latent_z = torch.cat(latent_z, dim=-1)
842
+
843
+ for i, l in enumerate(self.convs):
844
+ latent_z = l(latent_z)
845
+
846
+ latent_z = self.activation(latent_z)
847
+ fmap.append(latent_z)
848
+
849
+ latent_z = self.conv_post(latent_z)
850
+
851
+ return latent_z, fmap
852
+
853
+
854
+
855
+ class MultiScaleSubbandCQTDiscriminator(nn.Module): # replacing "MultiResSpecDiscriminator"
856
+ def __init__(self):
857
+ super(MultiScaleSubbandCQTDiscriminator, self).__init__()
858
+ cfg = Munch(multiscale_subband_cfg)
859
+ self.cfg = cfg
860
+ self.discriminators = nn.ModuleList(
861
+ [
862
+ DiscriminatorCQT(
863
+ cfg,
864
+ hop_length=cfg.hop_lengths[i],
865
+ n_octaves=cfg.n_octaves[i],
866
+ bins_per_octave=cfg.bins_per_octaves[i],
867
+ )
868
+ for i in range(len(cfg.hop_lengths))
869
+ ]
870
+ )
871
+
872
+ def forward(self, y, y_hat):
873
+ y_d_rs = []
874
+ y_d_gs = []
875
+ fmap_rs = []
876
+ fmap_gs = []
877
+
878
+ for disc in self.discriminators:
879
+ y_d_r, fmap_r = disc(y)
880
+ y_d_g, fmap_g = disc(y_hat)
881
+ y_d_rs.append(y_d_r)
882
+ fmap_rs.append(fmap_r)
883
+ y_d_gs.append(y_d_g)
884
+ fmap_gs.append(fmap_g)
885
+
886
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
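MultiScaleSubbandCQTDiscriminator stands in for the commented-out MultiResSpecDiscriminator: three DiscriminatorCQT heads share the config above and differ only in hop length and CQT resolution. It consumes mono waveforms shaped (B, 1, T) at `multiscale_subband_cfg["sampling_rate"]`, the same tensors the training loop feeds it. A quick sketch with random stand-in audio:

```python
import torch
from models import MultiScaleSubbandCQTDiscriminator  # defined above

msd = MultiScaleSubbandCQTDiscriminator()

y     = torch.randn(2, 1, 44100)   # "real" batch (random stand-in), 1 s at 44.1 kHz
y_hat = torch.randn(2, 1, 44100)   # "generated" batch

y_d_rs, y_d_gs, fmap_rs, fmap_gs = msd(y, y_hat)
print(len(y_d_rs), y_d_rs[0].shape)  # 3 sub-discriminators, one logit map per scale
```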
887
+
888
+
889
+
890
+ #############################
891
+
892
+
893
+
894
+ def feature_loss(fmap_r, fmap_g):
895
+ loss = 0
896
+ for dr, dg in zip(fmap_r, fmap_g):
897
+ for rl, gl in zip(dr, dg):
898
+ loss += torch.mean(torch.abs(rl - gl))
899
+
900
+ return loss*2
901
+
902
+
903
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
904
+ loss = 0
905
+ r_losses = []
906
+ g_losses = []
907
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
908
+ r_loss = torch.mean((1-dr)**2)
909
+ g_loss = torch.mean(dg**2)
910
+ loss += (r_loss + g_loss)
911
+ r_losses.append(r_loss.item())
912
+ g_losses.append(g_loss.item())
913
+
914
+ return loss, r_losses, g_losses
915
+
916
+
917
+ def generator_loss(disc_outputs):
918
+ loss = 0
919
+ gen_losses = []
920
+ for dg in disc_outputs:
921
+ l = torch.mean((1-dg)**2)
922
+ gen_losses.append(l)
923
+ loss += l
924
+
925
+ return loss, gen_losses
926
+
927
+ def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
928
+ loss = 0
929
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
930
+ tau = 0.04
931
+ m_DG = torch.median((dr-dg))
932
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
933
+ loss += tau - F.relu(tau - L_rel)
934
+ return loss
935
+
936
+ def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
937
+ loss = 0
938
+ for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
939
+ tau = 0.04
940
+ m_DG = torch.median((dr-dg))
941
+ L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
942
+ loss += tau - F.relu(tau - L_rel)
943
+ return loss
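These are the least-squares GAN objectives from HiFi-GAN (the discriminator pushes real outputs toward 1 and generated outputs toward 0, the generator pushes its outputs toward 1), plus the truncated pointwise relativistic (TPRLS) terms, which penalize only the points whose real-minus-fake margin falls below the median margin and clip the penalty at τ = 0.04. A toy check with hand-picked, purely illustrative numbers:

```python
import torch
from models import (discriminator_loss, generator_loss,
                    discriminator_TPRLS_loss)

# one sub-discriminator, outputs already close to the LSGAN targets
d_real = [torch.tensor([0.8, 1.0, 1.2])]   # should sit near 1
d_fake = [torch.tensor([0.3, 0.0, -0.3])]  # should sit near 0

loss_d, r_losses, g_losses = discriminator_loss(d_real, d_fake)
loss_g, gen_losses = generator_loss(d_fake)
loss_rel = discriminator_TPRLS_loss(d_real, d_fake)

# loss_d is small (the discriminator already separates the two sets),
# loss_g is large (the generator has not fooled it), loss_rel is capped at tau = 0.04
print(loss_d.item(), loss_g.item(), loss_rel.item())
```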
RingFormer/norm2d.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ #################### Norm2D for Discriminators ####################
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import einops
11
+ from torch.nn.utils import spectral_norm, weight_norm
12
+
13
+ CONV_NORMALIZATIONS = frozenset(
14
+ [
15
+ "none",
16
+ "weight_norm",
17
+ "spectral_norm",
18
+ "time_layer_norm",
19
+ "layer_norm",
20
+ "time_group_norm",
21
+ ]
22
+ )
23
+
24
+
25
+ class ConvLayerNorm(nn.LayerNorm):
26
+ """
27
+ Convolution-friendly LayerNorm that moves channels to last dimensions
28
+ before running the normalization and moves them back to original position right after.
29
+ """
30
+
31
+ def __init__(self, normalized_shape, **kwargs):
32
+ super().__init__(normalized_shape, **kwargs)
33
+
34
+ def forward(self, x):
35
+ x = einops.rearrange(x, "b ... t -> b t ...")
36
+ x = super().forward(x)
37
+ x = einops.rearrange(x, "b t ... -> b ... t")
38
+ return x
39
+
40
+
41
+ def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
42
+ assert norm in CONV_NORMALIZATIONS
43
+ if norm == "weight_norm":
44
+ return weight_norm(module)
45
+ elif norm == "spectral_norm":
46
+ return spectral_norm(module)
47
+ else:
48
+ # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
49
+ # doesn't need reparametrization.
50
+ return module
51
+
52
+
53
+ def get_norm_module(
54
+ module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
55
+ ) -> nn.Module:
56
+ """Return the proper normalization module. If causal is True, this will ensure the returned
57
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
58
+ """
59
+ assert norm in CONV_NORMALIZATIONS
60
+ if norm == "layer_norm":
61
+ assert isinstance(module, nn.modules.conv._ConvNd)
62
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
63
+ elif norm == "time_group_norm":
64
+ if causal:
65
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
66
+ assert isinstance(module, nn.modules.conv._ConvNd)
67
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
68
+ else:
69
+ return nn.Identity()
70
+
71
+
72
+ class NormConv2d(nn.Module):
73
+ """Wrapper around Conv2d and normalization applied to this conv
74
+ to provide a uniform interface across normalization approaches.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ *args,
80
+ norm: str = "none",
81
+ norm_kwargs={},
82
+ **kwargs,
83
+ ):
84
+ super().__init__()
85
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
86
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
87
+ self.norm_type = norm
88
+
89
+ def forward(self, x):
90
+ x = self.conv(x)
91
+ x = self.norm(x)
92
+ return x
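NormConv2d is the piece of this file the CQT discriminators consume: it wraps an nn.Conv2d in weight or spectral norm via `apply_parametrization_norm` and appends the post-norm module chosen by `get_norm_module`, which is an identity for "weight_norm". A quick sketch using the same (3, 9) kernel as DiscriminatorCQT, with the padding written out by hand:

```python
import torch
from norm2d import NormConv2d

# two input channels: CQT amplitude + phase, as stacked in DiscriminatorCQT
x = torch.randn(1, 2, 100, 24)                      # (B, C, time, frequency bins)

conv = NormConv2d(2, 32, kernel_size=(3, 9), padding=(1, 4), norm="weight_norm")
print(conv(x).shape)                                # torch.Size([1, 32, 100, 24])
```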
RingFormer/requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch
2
+ numpy
3
+ librosa
4
+ scipy==1.4.1
5
+ tensorboard==2.0
6
+ soundfile
7
+ matplotlib==3.1.3
8
+ nnaudio
9
+ ring-attention-pytorch
10
+ --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/triton-nightly
RingFormer/stft.py ADDED
@@ -0,0 +1,254 @@
1
+ """
2
+ BSD 3-Clause License
3
+ Copyright (c) 2017, Prem Seetharaman
4
+ All rights reserved.
5
+ * Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+ * Redistributions of source code must retain the above copyright notice,
8
+ this list of conditions and the following disclaimer.
9
+ * Redistributions in binary form must reproduce the above copyright notice, this
10
+ list of conditions and the following disclaimer in the
11
+ documentation and/or other materials provided with the distribution.
12
+ * Neither the name of the copyright holder nor the names of its
13
+ contributors may be used to endorse or promote products derived from this
14
+ software without specific prior written permission.
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ """
26
+
27
+ import torch
28
+ import numpy as np
29
+ import torch.nn.functional as F
30
+ from torch.autograd import Variable
31
+ from scipy.signal import get_window
32
+ from librosa.util import pad_center, tiny
33
+ import librosa.util as librosa_util
34
+
35
+
36
+ def window_sumsquare(
37
+ window,
38
+ n_frames,
39
+ hop_length=200,
40
+ win_length=800,
41
+ n_fft=800,
42
+ dtype=np.float32,
43
+ norm=None,
44
+ ):
45
+ """
46
+ # from librosa 0.6
47
+ Compute the sum-square envelope of a window function at a given hop length.
48
+ This is used to estimate modulation effects induced by windowing
49
+ observations in short-time fourier transforms.
50
+ Parameters
51
+ ----------
52
+ window : string, tuple, number, callable, or list-like
53
+ Window specification, as in `get_window`
54
+ n_frames : int > 0
55
+ The number of analysis frames
56
+ hop_length : int > 0
57
+ The number of samples to advance between frames
58
+ win_length : [optional]
59
+ The length of the window function. By default, this matches `n_fft`.
60
+ n_fft : int > 0
61
+ The length of each analysis frame.
62
+ dtype : np.dtype
63
+ The data type of the output
64
+ Returns
65
+ -------
66
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
67
+ The sum-squared envelope of the window function
68
+ """
69
+ if win_length is None:
70
+ win_length = n_fft
71
+
72
+ n = n_fft + hop_length * (n_frames - 1)
73
+ x = np.zeros(n, dtype=dtype)
74
+
75
+ # Compute the squared window at the desired length
76
+ win_sq = get_window(window, win_length, fftbins=True)
77
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
78
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
79
+
80
+ # Fill the envelope
81
+ for i in range(n_frames):
82
+ sample = i * hop_length
83
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
84
+ return x
85
+
86
+
87
+ def stft(x, fft_size, hop_size, win_length, window):
88
+ """Perform STFT and convert to magnitude spectrogram.
89
+ Args:
90
+ x (Tensor): Input signal tensor (B, T).
91
+ fft_size (int): FFT size.
92
+ hop_size (int): Hop size.
93
+ win_length (int): Window length.
94
+ window (str): Window function type.
95
+ Returns:
96
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
97
+ """
98
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
99
+ real = x_stft[..., 0]
100
+ imag = x_stft[..., 1]
101
+
102
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
103
+ return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
104
+
105
+
106
+ class STFT(torch.nn.Module):
107
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
108
+
109
+ def __init__(
110
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
111
+ ):
112
+ super(STFT, self).__init__()
113
+ self.filter_length = filter_length
114
+ self.hop_length = hop_length
115
+ self.win_length = win_length
116
+ self.window = window
117
+ self.forward_transform = None
118
+ scale = self.filter_length / self.hop_length
119
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
120
+
121
+ cutoff = int((self.filter_length / 2 + 1))
122
+ fourier_basis = np.vstack(
123
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
124
+ )
125
+
126
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
127
+ inverse_basis = torch.FloatTensor(
128
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
129
+ )
130
+
131
+ if window is not None:
132
+ assert filter_length >= win_length
133
+ # get window and zero center pad it to filter_length
134
+ fft_window = get_window(window, win_length, fftbins=True)
135
+ fft_window = pad_center(fft_window, filter_length)
136
+ fft_window = torch.from_numpy(fft_window).float()
137
+
138
+ # window the bases
139
+ forward_basis *= fft_window
140
+ inverse_basis *= fft_window
141
+
142
+ self.register_buffer("forward_basis", forward_basis.float())
143
+ self.register_buffer("inverse_basis", inverse_basis.float())
144
+
145
+ def transform(self, input_data):
146
+ num_batches = input_data.size(0)
147
+ num_samples = input_data.size(1)
148
+
149
+ self.num_samples = num_samples
150
+
151
+ # similar to librosa, reflect-pad the input
152
+ input_data = input_data.view(num_batches, 1, num_samples)
153
+ input_data = F.pad(
154
+ input_data.unsqueeze(1),
155
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
156
+ mode="reflect",
157
+ )
158
+ input_data = input_data.squeeze(1)
159
+
160
+ forward_transform = F.conv1d(
161
+ input_data,
162
+ Variable(self.forward_basis, requires_grad=False),
163
+ stride=self.hop_length,
164
+ padding=0,
165
+ )
166
+
167
+ cutoff = int((self.filter_length / 2) + 1)
168
+ real_part = forward_transform[:, :cutoff, :]
169
+ imag_part = forward_transform[:, cutoff:, :]
170
+
171
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
172
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
173
+
174
+ return magnitude, phase
175
+
176
+ def inverse(self, magnitude, phase):
177
+ recombine_magnitude_phase = torch.cat(
178
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
179
+ )
180
+
181
+ inverse_transform = F.conv_transpose1d(
182
+ recombine_magnitude_phase,
183
+ Variable(self.inverse_basis, requires_grad=False),
184
+ stride=self.hop_length,
185
+ padding=0,
186
+ )
187
+
188
+ if self.window is not None:
189
+ window_sum = window_sumsquare(
190
+ self.window,
191
+ magnitude.size(-1),
192
+ hop_length=self.hop_length,
193
+ win_length=self.win_length,
194
+ n_fft=self.filter_length,
195
+ dtype=np.float32,
196
+ )
197
+ # remove modulation effects
198
+ approx_nonzero_indices = torch.from_numpy(
199
+ np.where(window_sum > tiny(window_sum))[0]
200
+ )
201
+ window_sum = torch.autograd.Variable(
202
+ torch.from_numpy(window_sum), requires_grad=False
203
+ )
204
+ window_sum = (
205
+ window_sum.to(inverse_transform.device)
206
+ if magnitude.is_cuda
207
+ else window_sum
208
+ )
209
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
210
+ approx_nonzero_indices
211
+ ]
212
+
213
+ # scale by hop ratio
214
+ inverse_transform *= float(self.filter_length) / self.hop_length
215
+
216
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
217
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
218
+
219
+ return inverse_transform
220
+
221
+ def forward(self, input_data):
222
+ self.magnitude, self.phase = self.transform(input_data)
223
+ reconstruction = self.inverse(self.magnitude, self.phase)
224
+ return reconstruction
225
+
226
+
227
+
228
+ class TorchSTFT(torch.nn.Module):
229
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
230
+ super().__init__()
231
+ self.filter_length = filter_length
232
+ self.hop_length = hop_length
233
+ self.win_length = win_length
234
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
235
+
236
+ def transform(self, input_data):
237
+ forward_transform = torch.stft(
238
+ input_data,
239
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
240
+ return_complex=True)
241
+
242
+ return torch.abs(forward_transform), torch.angle(forward_transform)
243
+
244
+ def inverse(self, magnitude, phase):
245
+ inverse_transform = torch.istft(
246
+ magnitude * torch.exp(phase * 1j),
247
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
248
+
249
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
250
+
251
+ def forward(self, input_data):
252
+ self.magnitude, self.phase = self.transform(input_data)
253
+ reconstruction = self.inverse(self.magnitude, self.phase)
254
+ return reconstruction
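TorchSTFT is the iSTFT head: train.py calls `stft.inverse(spec, phase)` on the generator output, and the Generator itself uses `transform` to split the harmonic source into magnitude and phase. Because `inverse` recombines magnitude and angle with `torch.istft`, a transform→inverse round trip reproduces the input up to float error. A small sketch with illustrative sizes (the real ones are `gen_istft_n_fft` / `gen_istft_hop_size` from the config):

```python
import torch
from stft import TorchSTFT

stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16)  # illustrative sizes

audio = torch.randn(1, 4000)                 # (B, T) mono waveform
mag, phase = stft.transform(audio)           # each (B, 9, 1001) for these sizes
recon = stft.inverse(mag, phase)             # (B, 1, T)

err = (audio - recon[:, 0, :]).abs().max()
print(mag.shape, recon.shape, err)           # err ~ 1e-6: near-perfect reconstruction
```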
RingFormer/train.py ADDED
@@ -0,0 +1,671 @@
1
+ # import warnings
2
+ # warnings.simplefilter(action='ignore', category=FutureWarning)
3
+ # import itertools
4
+ # import os
5
+ # import time
6
+ # import argparse
7
+ # import json
8
+ # import torch
9
+ # import torch.nn.functional as F
10
+ # from torch.utils.tensorboard import SummaryWriter
11
+ # from torch.utils.data import DistributedSampler, DataLoader
12
+ # import torch.multiprocessing as mp
13
+ # from torch.distributed import init_process_group
14
+ # from torch.nn.parallel import DistributedDataParallel
15
+ # from env import AttrDict, build_env
16
+ # from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
17
+ # from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\
18
+ # discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss, MultiScaleSubbandCQTDiscriminator
19
+ # from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
20
+ # from stft import TorchSTFT
21
+ # from Utils.JDC.model import JDCNet
22
+
23
+ # torch.backends.cudnn.benchmark = True
24
+
25
+
26
+ # def train(rank, a, h):
27
+ # if h.num_gpus > 1:
28
+ # init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
29
+ # world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
30
+
31
+ # torch.cuda.manual_seed(h.seed)
32
+ # device = torch.device('cuda:{:d}'.format(rank))
33
+
34
+ # F0_model = JDCNet(num_class=1, seq_len=192)
35
+ # params = torch.load(h.F0_path)['net']
36
+ # F0_model.load_state_dict(params)
37
+
38
+ # generator = Generator(h, F0_model, device=device).to(device)
39
+ # mpd = MultiPeriodDiscriminator().to(device)
40
+ # msd = MultiScaleSubbandCQTDiscriminator().to(device)
41
+ # stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
42
+
43
+ # if rank == 0:
44
+ # print(generator)
45
+ # os.makedirs(a.checkpoint_path, exist_ok=True)
46
+ # print("checkpoints directory : ", a.checkpoint_path)
47
+
48
+ # if os.path.isdir(a.checkpoint_path):
49
+ # cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
50
+ # cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
51
+
52
+ # steps = 0
53
+ # if cp_g is None or cp_do is None:
54
+ # state_dict_do = None
55
+ # last_epoch = -1
56
+ # else:
57
+ # state_dict_g = load_checkpoint(cp_g, device)
58
+ # state_dict_do = load_checkpoint(cp_do, device)
59
+ # generator.load_state_dict(state_dict_g['generator'])
60
+ # mpd.load_state_dict(state_dict_do['mpd'])
61
+ # msd.load_state_dict(state_dict_do['msd'])
62
+ # steps = state_dict_do['steps'] + 1
63
+ # last_epoch = state_dict_do['epoch']
64
+
65
+ # if h.num_gpus > 1:
66
+ # generator = DistributedDataParallel(generator, device_ids=[rank], find_unused_parameters=True).to(device)
67
+ # mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
68
+ # msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
69
+
70
+ # optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
71
+ # optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
72
+ # h.learning_rate, betas=[h.adam_b1, h.adam_b2])
73
+
74
+ # if state_dict_do is not None:
75
+ # optim_g.load_state_dict(state_dict_do['optim_g'])
76
+ # optim_d.load_state_dict(state_dict_do['optim_d'])
77
+
78
+ # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
79
+ # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
80
+
81
+ # training_filelist, validation_filelist = get_dataset_filelist(a)
82
+
83
+ # trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
84
+ # h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
85
+ # shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
86
+ # fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
87
+
88
+ # train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
89
+
90
+ # train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
91
+ # sampler=train_sampler,
92
+ # batch_size=h.batch_size,
93
+ # pin_memory=True,
94
+ # drop_last=True)
95
+
96
+
97
+
98
+
99
+ # if rank == 0:
100
+ # validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
101
+ # h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
102
+ # fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
103
+ # base_mels_path=a.input_mels_dir)
104
+ # validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
105
+ # sampler=None,
106
+ # batch_size=1,
107
+ # pin_memory=True,
108
+ # drop_last=True)
109
+
110
+ # sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
111
+
112
+ # generator.train()
113
+ # mpd.train()
114
+ # msd.train()
115
+ # for epoch in range(max(0, last_epoch), a.training_epochs):
116
+ # if rank == 0:
117
+ # start = time.time()
118
+ # print("Epoch: {}".format(epoch+1))
119
+
120
+ # if h.num_gpus > 1:
121
+ # train_sampler.set_epoch(epoch)
122
+
123
+ # for i, batch in enumerate(train_loader):
124
+ # if rank == 0:
125
+ # start_b = time.time()
126
+ # x, y, _, y_mel = batch
127
+ # x = torch.autograd.Variable(x.to(device, non_blocking=True))
128
+ # y = torch.autograd.Variable(y.to(device, non_blocking=True))
129
+ # y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
130
+ # y = y.unsqueeze(1)
131
+ # # y_g_hat = generator(x)
132
+ # spec, phase = generator(x)
133
+
134
+ # y_g_hat = stft.inverse(spec, phase)
135
+
136
+ # y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
137
+ # h.fmin, h.fmax_for_loss)
138
+
139
+ # optim_d.zero_grad()
140
+
141
+ # # MPD
142
+ # y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
143
+ # loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
144
+ # loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
145
+
146
+ # # MSD
147
+ # y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
148
+ # loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
149
+ # loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
150
+
151
+ # loss_disc_all = loss_disc_s + loss_disc_f
152
+
153
+ # loss_disc_all.backward()
154
+ # optim_d.step()
155
+
156
+ # # Generator
157
+ # optim_g.zero_grad()
158
+
159
+ # # L1 Mel-Spectrogram Loss
160
+ # loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
161
+
162
+ # y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
163
+ # y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
164
+ # loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
165
+ # loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
166
+ # loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
167
+ # loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
168
+
169
+ # loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
170
+ # loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
171
+
172
+ # loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
173
+
174
+ # loss_gen_all.backward()
175
+ # optim_g.step()
176
+
177
+ # if rank == 0:
178
+ # # STDOUT logging
179
+ # if steps % a.stdout_interval == 0:
180
+ # with torch.no_grad():
181
+ # mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
182
+
183
+ # print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
184
+ # format(steps, loss_gen_all, mel_error, time.time() - start_b))
185
+
186
+ # # checkpointing
187
+ # if steps % a.checkpoint_interval == 0 and steps != 0:
188
+ # checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
189
+ # save_checkpoint(checkpoint_path,
190
+ # {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
191
+ # checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
192
+ # save_checkpoint(checkpoint_path,
193
+ # {'mpd': (mpd.module if h.num_gpus > 1
194
+ # else mpd).state_dict(),
195
+ # 'msd': (msd.module if h.num_gpus > 1
196
+ # else msd).state_dict(),
197
+ # 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
198
+ # 'epoch': epoch})
199
+
200
+ # # Tensorboard summary logging
201
+ # if steps % a.summary_interval == 0:
202
+ # sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
203
+ # sw.add_scalar("training/mel_spec_error", mel_error, steps)
204
+
205
+ # # Validation
206
+ # if steps % a.validation_interval == 0: # and steps != 0:
207
+ # generator.eval()
208
+ # torch.cuda.empty_cache()
209
+ # val_err_tot = 0
210
+ # with torch.no_grad():
211
+ # for j, batch in enumerate(validation_loader):
212
+ # x, y, _, y_mel = batch
213
+ # # y_g_hat = generator(x.to(device))
214
+ # spec, phase = generator(x.to(device))
215
+
216
+ # y_g_hat = stft.inverse(spec, phase)
217
+
218
+ # y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
219
+ # y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
220
+ # h.hop_size, h.win_size,
221
+ # h.fmin, h.fmax_for_loss)
222
+ # val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
223
+
224
+ # if j <= 4:
225
+ # if steps == 0:
226
+ # sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
227
+ # sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
228
+
229
+ # sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
230
+ # y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
231
+ # h.sampling_rate, h.hop_size, h.win_size,
232
+ # h.fmin, h.fmax)
233
+ # sw.add_figure('generated/y_hat_spec_{}'.format(j),
234
+ # plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
235
+
236
+ # val_err = val_err_tot / (j+1)
237
+ # sw.add_scalar("validation/mel_spec_error", val_err, steps)
238
+
239
+ # generator.train()
240
+
241
+ # steps += 1
242
+
243
+ # scheduler_g.step()
244
+ # scheduler_d.step()
245
+
246
+ # if rank == 0:
247
+ # print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
248
+
249
+
250
+ # def main():
251
+ # print('Initializing Training Process..')
252
+
253
+ # parser = argparse.ArgumentParser()
254
+
255
+ # parser.add_argument('--group_name', default=None)
256
+ # parser.add_argument('--input_wavs_dir', default='')
257
+ # parser.add_argument('--input_mels_dir', default='ft_dataset')
258
+ # # parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/training.txt')
259
+ # parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/validation.txt')
260
+ # parser.add_argument('--input_validation_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/validation.txt')
261
+ # parser.add_argument('--checkpoint_path', default='cp_ringformer_24')
262
+ # parser.add_argument('--config', default='config_v1.json')
263
+ # parser.add_argument('--training_epochs', default=3100, type=int)
264
+ # parser.add_argument('--stdout_interval', default=5, type=int)
265
+ # parser.add_argument('--checkpoint_interval', default=5000, type=int)
266
+ # parser.add_argument('--summary_interval', default=100, type=int)
267
+ # parser.add_argument('--validation_interval', default=1000, type=int)
268
+ # parser.add_argument('--fine_tuning', default=False, type=bool)
269
+
270
+ # a = parser.parse_args()
271
+
272
+ # with open(a.config) as f:
273
+ # data = f.read()
274
+
275
+ # json_config = json.loads(data)
276
+ # h = AttrDict(json_config)
277
+ # build_env(a.config, 'config.json', a.checkpoint_path)
278
+
279
+ # torch.manual_seed(h.seed)
280
+ # if torch.cuda.is_available():
281
+ # torch.cuda.manual_seed(h.seed)
282
+ # h.num_gpus = torch.cuda.device_count()
283
+ # h.batch_size = int(h.batch_size / h.num_gpus)
284
+ # print('Batch size per GPU :', h.batch_size)
285
+ # else:
286
+ # pass
287
+
288
+ # if h.num_gpus > 1:
289
+ # mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
290
+ # else:
291
+ # train(0, a, h)
292
+
293
+
294
+ # if __name__ == '__main__':
295
+ # main()
296
+
297
+
298
+ import warnings
299
+ warnings.simplefilter(action='ignore', category=FutureWarning)
300
+ import itertools
301
+ import os
302
+ import time
303
+ import argparse
304
+ import json
305
+ import torch
306
+ import torch.nn.functional as F
307
+ from torch.utils.tensorboard import SummaryWriter
308
+ from torch.utils.data import DistributedSampler, DataLoader
309
+ import torch.multiprocessing as mp
310
+ from torch.distributed import init_process_group
311
+ from torch.nn.parallel import DistributedDataParallel
312
+ from env import AttrDict, build_env
313
+ from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
314
+ from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\
315
+ discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss, MultiScaleSubbandCQTDiscriminator
316
+ from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
317
+ from stft import TorchSTFT
318
+ from Utils.JDC.model import JDCNet
319
+ from accelerate import Accelerator
320
+ from accelerate.utils import LoggerType
321
+ from accelerate import DistributedDataParallelKwargs
322
+
323
+
324
+ torch.backends.cudnn.benchmark = True
325
+
326
+
327
+ def train(accelerator, a, h):
328
+ # if h.num_gpus > 1:
329
+ # init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
330
+ # world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
331
+
332
+
333
+
334
+ torch.cuda.manual_seed(h.seed)
335
+ device = accelerator.device
336
+
337
+ F0_model = JDCNet(num_class=1, seq_len=192)
338
+ params = torch.load(h.F0_path)['model']
339
+ F0_model.load_state_dict(params)
340
+
341
+ generator = Generator(h, F0_model).to(device)
342
+ mpd = MultiPeriodDiscriminator().to(device)
343
+ msd = MultiScaleSubbandCQTDiscriminator().to(device)
344
+ stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
345
+
346
+ with accelerator.main_process_first():
347
+ accelerator.print(generator)
348
+ os.makedirs(a.checkpoint_path, exist_ok=True)
349
+ accelerator.print("checkpoints directory : ", a.checkpoint_path)
350
+
351
+ if os.path.isdir(a.checkpoint_path):
352
+ cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
353
+ cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
354
+
355
+ steps = 0
356
+ if cp_g is None or cp_do is None:
357
+ state_dict_do = None
358
+ last_epoch = -1
359
+ else:
360
+ state_dict_g = load_checkpoint(cp_g, device)
361
+ state_dict_do = load_checkpoint(cp_do, device)
362
+ generator.load_state_dict(state_dict_g['generator'])
363
+ mpd.load_state_dict(state_dict_do['mpd'])
364
+ msd.load_state_dict(state_dict_do['msd'])
365
+ steps = state_dict_do['steps'] + 1
366
+ last_epoch = state_dict_do['epoch']
367
+
368
+ # if h.num_gpus > 1:
369
+ generator = accelerator.prepare(generator).to(device)
370
+ mpd = accelerator.prepare(mpd).to(device)
371
+ msd = accelerator.prepare(msd).to(device)
372
+
373
+ optim_g = accelerator.prepare(torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]))
374
+ optim_d = accelerator.prepare(torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
375
+ h.learning_rate, betas=[h.adam_b1, h.adam_b2]))
376
+
377
+
378
+
379
+ if state_dict_do is not None:
380
+ optim_g.load_state_dict(state_dict_do['optim_g'])
381
+ optim_d.load_state_dict(state_dict_do['optim_d'])
382
+
383
+ scheduler_g = accelerator.prepare(torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch))
384
+ scheduler_d = accelerator.prepare(torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch))
385
+
386
+ training_filelist, validation_filelist = get_dataset_filelist(a)
387
+
388
+ trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
389
+ h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
390
+ shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
391
+ fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
392
+
393
+ # train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
394
+
395
+ accelerator.print("bs", h.batch_size)
396
+
397
+ train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
398
+ # sampler=train_sampler,
399
+ batch_size=h.batch_size,
400
+ pin_memory=True,
401
+ drop_last=True)
402
+
403
+
404
+ train_loader = accelerator.prepare(train_loader)
405
+
406
+
407
+
408
+ validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
409
+ h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
410
+ fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
411
+ base_mels_path=a.input_mels_dir)
412
+
413
+
414
+
415
+ validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
416
+ sampler=None,
417
+ batch_size=1,
418
+ pin_memory=True,
419
+ drop_last=True)
420
+
421
+ sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
422
+
423
+ # validation_loader = accelerator.prepare(validation_loader)
424
+
425
+ generator.train()
426
+ mpd.train()
427
+ msd.train()
428
+ for epoch in range(max(0, last_epoch), a.training_epochs):
429
+ with accelerator.main_process_first():
430
+ start = time.time()
431
+ accelerator.print("Epoch: {}".format(epoch+1))
432
+
433
+ # if h.num_gpus > 1:
434
+ # train_sampler.set_epoch(epoch)
435
+
436
+ for i, batch in enumerate(train_loader):
437
+ # with accelerator.main_process_first():
438
+ start_b = time.time()
439
+ x, y, _, y_mel = batch
440
+ x = torch.autograd.Variable(x.to(device, non_blocking=True))
441
+ y = torch.autograd.Variable(y.to(device, non_blocking=True))
442
+ y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
443
+ y = y.unsqueeze(1)
444
+ # y_g_hat = generator(x)
445
+ spec, phase = generator(x)
446
+
447
+ y_g_hat = stft.inverse(spec, phase)
448
+
449
+ y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
450
+ h.fmin, h.fmax_for_loss)
451
+
452
+ optim_d.zero_grad()
453
+
454
+ # MPD
455
+ y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
456
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
457
+ loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
458
+
459
+ # MSD
460
+ y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
461
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
462
+ loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
463
+
464
+ loss_disc_all = loss_disc_s + loss_disc_f
465
+
466
+ accelerator.backward(loss_disc_all)
467
+ # accelerator.clip_grad_norm_(mpd.parameters(), max_norm=10.0)
468
+ # accelerator.clip_grad_norm_(msd.parameters(), max_norm=10.0)
469
+
470
+ # loss_disc_all.backward()
471
+ optim_d.step()
472
+
473
+ # Generator
474
+ optim_g.zero_grad()
475
+
476
+ # L1 Mel-Spectrogram Loss
477
+ loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
478
+
479
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
480
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
481
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
482
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
483
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
484
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
485
+
486
+ loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
487
+ loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
488
+
489
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
490
+
491
+
492
+ accelerator.backward(loss_gen_all)
493
+
494
+ # accelerator.clip_grad_norm_(generator.parameters(), max_norm=10.0)
495
+
496
+ # loss_gen_all.backward()
497
+ optim_g.step()
498
+
499
+ # accelerator.print('done +',i)
500
+
501
+
502
+ # STDOUT logging
503
+ if steps % a.stdout_interval == 0:
504
+
505
+ with accelerator.main_process_first():
506
+
507
+ with torch.no_grad():
508
+ mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
509
+
510
+ accelerator.print('Steps : {:d}, Gen Loss Total : {:4.3f}, Disc Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
511
+ format(steps, loss_gen_all, loss_disc_all, mel_error, time.time() - start_b))
512
+
513
+ # checkpointing
514
+ # In the main code:
515
+ # if steps % a.checkpoint_interval == 0 and steps != 0:
516
+ # # with accelerator.main_process_first():
517
+ # # Unwrap the generator
518
+ # unwrapped_generator = accelerator.unwrap_model(generator)
519
+ # checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
520
+ # save_checkpoint(checkpoint_path,
521
+ # {'generator': unwrapped_generator.state_dict()})
522
+
523
+ # # Unwrap discriminators
524
+ # unwrapped_mpd = accelerator.unwrap_model(mpd)
525
+ # unwrapped_msd = accelerator.unwrap_model(msd)
526
+
527
+ # checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
528
+ # save_checkpoint(checkpoint_path,
529
+ # {'mpd': unwrapped_mpd.state_dict(),
530
+ # 'msd': unwrapped_msd.state_dict(),
531
+ # 'optim_g': optim_g.state_dict(),
532
+ # 'optim_d': optim_d.state_dict(),
533
+ # 'steps': steps,
534
+ # 'epoch': epoch})
535
+
536
+
537
+
538
+
539
+ if steps % a.checkpoint_interval == 0 and steps != 0:
540
+
541
+ with accelerator.main_process_first():
542
+ checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
543
+ save_checkpoint(checkpoint_path,
544
+ {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
545
+ checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
546
+ save_checkpoint(checkpoint_path,
547
+ {'mpd': (mpd.module if h.num_gpus > 1
548
+ else mpd).state_dict(),
549
+ 'msd': (msd.module if h.num_gpus > 1
550
+ else msd).state_dict(),
551
+ 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
552
+ 'epoch': epoch})
553
+
554
+ # Tensorboard summary logging
555
+ if steps % a.summary_interval == 0:
556
+ sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
557
+ sw.add_scalar("training/mel_spec_error", mel_error, steps)
558
+
559
+ # Validation
560
+ if steps % a.validation_interval == 0: # and steps != 0:
561
+
562
+ # with accelerator.main_process_first():
563
+ accelerator.print('evaluation...')
+
+ generator.eval()
+ torch.cuda.empty_cache()
+ val_err_tot = 0
+ with torch.no_grad():
+ for j, batch in enumerate(validation_loader):
+ x, y, _, y_mel = batch
+ # y_g_hat = generator(x.to(device))
+ spec, phase = generator(x.to('cuda'))
+
+ y_g_hat = stft.inverse(spec, phase)
+
+ y_mel = torch.autograd.Variable(y_mel.to('cuda', non_blocking=True))
+ y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
+ h.hop_size, h.win_size,
+ h.fmin, h.fmax_for_loss)
+
+ accelerator.print('done: ', j)
+
+ val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
+
+ if j <= 4:
+ if steps == 0:
+ sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
+ sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
+
+ sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
+ y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
+ h.sampling_rate, h.hop_size, h.win_size,
+ h.fmin, h.fmax)
+
+ sw.add_figure('generated/y_hat_spec_{}'.format(j),
+ plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
+
+ val_err = val_err_tot / (j+1)
+ sw.add_scalar("validation/mel_spec_error", val_err, steps)
+
+ generator.train()
+
+ steps += 1
+
+ scheduler_g.step()
+ scheduler_d.step()
+
+ # with accelerator.main_process_first():
+ accelerator.print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
+
+
+ def main():
+
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--group_name', default=None)
+ parser.add_argument('--input_wavs_dir', default='')
+ parser.add_argument('--input_mels_dir', default='ft_dataset')
+ # parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/training.txt')
+ parser.add_argument('--input_training_file', default='/home/ubuntu/respair/audio_files_over_2sec_SUB.txt')
+ parser.add_argument('--input_validation_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/validation.txt')
+ parser.add_argument('--checkpoint_path', default='cp_ringformer_44.1khz_NUM2')
+ parser.add_argument('--config', default='config_v1.json')
+ parser.add_argument('--training_epochs', default=3100, type=int)
+ parser.add_argument('--stdout_interval', default=10, type=int)
+ parser.add_argument('--checkpoint_interval', default=1000, type=int)
+ parser.add_argument('--summary_interval', default=100, type=int)
+ parser.add_argument('--validation_interval', default=1000, type=int)
+ parser.add_argument('--fine_tuning', default=False, type=bool)
+
+ a = parser.parse_args()
+
+ accelerator = Accelerator(project_dir=a.checkpoint_path, split_batches=False,
+ kwargs_handlers=[ddp_kwargs])
+
+ accelerator.print('Initializing Training Process..')
+
+ with open(a.config) as f:
+ data = f.read()
+
+ json_config = json.loads(data)
+ h = AttrDict(json_config)
+ build_env(a.config, 'config.json', a.checkpoint_path)
+
+ torch.manual_seed(h.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(h.seed)
+ h.num_gpus = torch.cuda.device_count()
+ # h.batch_size = int(h.batch_size / h.num_gpus)
+ # print('Batch size per GPU :', h.batch_size)
+ else:
+ pass
+ # Pass accelerator to train function instead of rank
+ train(accelerator, a, h)
+
+ if __name__ == '__main__':
+ main()
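
The checkpoints written above keep the generator weights under the 'generator' key (files named g_XXXXXXXX) and the discriminators, optimizer state and step counter in the do_XXXXXXXX files. A minimal sketch of reloading such a generator checkpoint, assuming AttrDict lives in env.py as in the upstream HiFi-GAN layout and that models.py exposes a Generator class taking the config as its only argument (that constructor signature and the checkpoint directory name are assumptions, not verified against models.py):

    import json
    import torch

    from env import AttrDict
    from models import Generator              # assumed signature: Generator(h)
    from utils import load_checkpoint, scan_checkpoint

    with open('config_v1.json') as f:
        h = AttrDict(json.loads(f.read()))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = Generator(h).to(device)

    # pick the newest g_XXXXXXXX file in the checkpoint directory
    cp_g = scan_checkpoint('cp_ringformer_44.1khz_NUM2', 'g_')
    generator.load_state_dict(load_checkpoint(cp_g, device)['generator'])
    generator.eval()

As in the validation loop above, the generator returns a (spec, phase) pair, so a waveform still has to be recovered with stft.inverse(spec, phase) from stft.py.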
RingFormer/utils.py ADDED
@@ -0,0 +1,58 @@
+ import glob
+ import os
+ import matplotlib
+ import torch
+ from torch.nn.utils import weight_norm
+ matplotlib.use("Agg")
+ import matplotlib.pylab as plt
+
+
+ def plot_spectrogram(spectrogram):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                    interpolation='none')
+     plt.colorbar(im, ax=ax)
+
+     fig.canvas.draw()
+     plt.close()
+
+     return fig
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print("Loading '{}'".format(filepath))
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ def save_checkpoint(filepath, obj):
+     print("Saving checkpoint to {}".format(filepath))
+     torch.save(obj, filepath)
+     print("Complete.")
+
+
+ def scan_checkpoint(cp_dir, prefix):
+     pattern = os.path.join(cp_dir, prefix + '????????')
+     cp_list = glob.glob(pattern)
+     if len(cp_list) == 0:
+         return None
+     return sorted(cp_list)[-1]
+
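
get_padding above evaluates to dilation * (kernel_size - 1) / 2, i.e. the symmetric padding that keeps a stride-1 Conv1d length-preserving for odd kernel sizes. A small sketch of that property (the Conv1d settings are illustrative only, not taken from models.py):

    import torch
    from torch import nn
    from utils import get_padding

    assert get_padding(3) == 1               # dilation 1: pad by (3 - 1) / 2
    assert get_padding(7, dilation=3) == 9   # dilation 3: pad by 3 * (7 - 1) / 2

    conv = nn.Conv1d(1, 1, kernel_size=7, dilation=3, padding=get_padding(7, 3))
    x = torch.randn(1, 1, 100)
    print(conv(x).shape)                     # torch.Size([1, 1, 100]) -- length preserved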