Update losses.py
losses.py
CHANGED
@@ -2,7 +2,10 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 import torchaudio
-from transformers import AutoModel


 class SpectralConvergengeLoss(torch.nn.Module):
@@ -22,25 +25,16 @@ class SpectralConvergengeLoss(torch.nn.Module):
         """
         return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)

-
 class STFTLoss(torch.nn.Module):
     """STFT loss module."""

-    def __init__(
-        self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
-    ):
         """Initialize STFT loss module."""
         super(STFTLoss, self).__init__()
         self.fft_size = fft_size
         self.shift_size = shift_size
         self.win_length = win_length
-        self.to_mel = torchaudio.transforms.MelSpectrogram(
-            sample_rate=24000,
-            n_fft=fft_size,
-            win_length=win_length,
-            hop_length=shift_size,
-            window_fn=window,
-        )

         self.spectral_convergenge_loss = SpectralConvergengeLoss()

@@ -56,25 +50,23 @@ class STFTLoss(torch.nn.Module):
         x_mag = self.to_mel(x)
         mean, std = -4, 4
         x_mag = (torch.log(1e-5 + x_mag) - mean) / std
-
         y_mag = self.to_mel(y)
         mean, std = -4, 4
         y_mag = (torch.log(1e-5 + y_mag) - mean) / std
-
-        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
         return sc_loss


 class MultiResolutionSTFTLoss(torch.nn.Module):
     """Multi resolution STFT loss module."""

-    def __init__(
-        self,
-        fft_sizes=[1024, 2048, 512],
-        hop_sizes=[120, 240, 50],
-        win_lengths=[600, 1200, 240],
-        window=torch.hann_window,
-    ):
         """Initialize Multi resolution STFT loss module.
         Args:
             fft_sizes (list): List of FFT sizes.
@@ -104,15 +96,15 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
         sc_loss /= len(self.stft_losses)

         return sc_loss
-
-
 def feature_loss(fmap_r, fmap_g):
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
             loss += torch.mean(torch.abs(rl - gl))

-    return loss * 2


 def discriminator_loss(disc_real_outputs, disc_generated_outputs):
@@ -120,9 +112,9 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     r_losses = []
     g_losses = []
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean((1 - dr) ** 2)
         g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
         r_losses.append(r_loss.item())
         g_losses.append(g_loss.item())

@@ -133,42 +125,38 @@ def generator_loss(disc_outputs):
     loss = 0
     gen_losses = []
     for dg in disc_outputs:
-        l = torch.mean((1 - dg) ** 2)
         gen_losses.append(l)
         loss += l

     return loss, gen_losses

-
 """ https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
-
-
 def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
     loss = 0
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
         tau = 0.04
-        m_DG = torch.median((dr - dg))
-        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
         loss += tau - F.relu(tau - L_rel)
     return loss

-
 def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
     loss = 0
     for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
         tau = 0.04
-        m_DG = torch.median((dr - dg))
-        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
         loss += tau - F.relu(tau - L_rel)
     return loss

-
 class GeneratorLoss(torch.nn.Module):
     def __init__(self, mpd, msd):
         super(GeneratorLoss, self).__init__()
         self.mpd = mpd
         self.msd = msd
-
     def forward(self, y, y_hat):
         y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
         y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
@@ -177,127 +165,224 @@ class GeneratorLoss(torch.nn.Module):
         loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
         loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

-        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(
-            y_ds_hat_r, y_ds_hat_g
-        )
-
         loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
-
         return loss_gen_all.mean()
-
-
 class DiscriminatorLoss(torch.nn.Module):
     def __init__(self, mpd, msd):
         super(DiscriminatorLoss, self).__init__()
         self.mpd = mpd
         self.msd = msd
-
     def forward(self, y, y_hat):
         # MPD
         y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
-        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
-            y_df_hat_r, y_df_hat_g
-        )
         # MSD
         y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
-        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
-            y_ds_hat_r, y_ds_hat_g
-        )

-        loss_rel = discriminator_TPRLS_loss(
-            y_df_hat_r, y_df_hat_g
-        ) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

         d_loss = loss_disc_s + loss_disc_f + loss_rel
-
         return d_loss.mean()


 class WavLMLoss(torch.nn.Module):
     def __init__(self, model, wd, model_sr, slm_sr=16000):
         super(WavLMLoss, self).__init__()
-        self.wavlm = AutoModel.from_pretrained(model)
         self.wd = wd
         self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)

-    def forward(self, wav, y_rec):
-        with torch.no_grad():
-            wav_16 = self.resample(wav)
-            wav_embeddings = self.wavlm(
-                input_values=wav_16, output_hidden_states=True
-            ).hidden_states
-        y_rec_16 = self.resample(y_rec)
-        y_rec_embeddings = self.wavlm(
-            input_values=y_rec_16, output_hidden_states=True
-        ).hidden_states

-        floss = 0
-        for er, eg in zip(wav_embeddings, y_rec_embeddings):
-            floss += torch.mean(torch.abs(er - eg))

-        return floss.mean()

     def generator(self, y_rec):
-        y_rec_16 = self.resample(y_rec)
-        y_rec_embeddings = self.wavlm(
-            input_values=y_rec_16, output_hidden_states=True
-        ).hidden_states
-        y_rec_embeddings = (
-            torch.stack(y_rec_embeddings, dim=1)
-            .transpose(-1, -2)
-            .flatten(start_dim=1, end_dim=2)
-        )
-        y_df_hat_g = self.wd(y_rec_embeddings)
-        loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
-
-        return loss_gen

-    def discriminator(self, wav, y_rec):
         with torch.no_grad():
-            wav_16 = self.resample(wav)
-            wav_embeddings = self.wavlm(
-                input_values=wav_16, output_hidden_states=True
-            ).hidden_states
-            y_rec_16 = self.resample(y_rec)
-            y_rec_embeddings = self.wavlm(
-                input_values=y_rec_16, output_hidden_states=True
-            ).hidden_states
-
-            y_embeddings = (
-                torch.stack(wav_embeddings, dim=1)
-                .transpose(-1, -2)
-                .flatten(start_dim=1, end_dim=2)
-            )
-            y_rec_embeddings = (
-                torch.stack(y_rec_embeddings, dim=1)
-                .transpose(-1, -2)
-                .flatten(start_dim=1, end_dim=2)
-            )
-
-            y_d_rs = self.wd(y_embeddings)
-            y_d_gs = self.wd(y_rec_embeddings)

-        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs

-        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
-        g_loss = torch.mean((y_df_hat_g) ** 2)

         loss_disc_f = r_loss + g_loss

-        return loss_disc_f.mean()

     def discriminator_forward(self, wav):
-        with torch.no_grad():
-            wav_16 = self.resample(wav)
-            wav_embeddings = self.wavlm(
-                input_values=wav_16, output_hidden_states=True
-            ).hidden_states
-            y_embeddings = (
-                torch.stack(wav_embeddings, dim=1)
-                .transpose(-1, -2)
-                .flatten(start_dim=1, end_dim=2)
-            )

-            y_d_rs = self.wd(y_embeddings)

         return y_d_rs
 from torch import nn
 import torch.nn.functional as F
 import torchaudio
+from transformers import AutoModel, WhisperConfig, WhisperPreTrainedModel
+import whisper
+
+from transformers.models.whisper.modeling_whisper import WhisperEncoder


 class SpectralConvergengeLoss(torch.nn.Module):
         """
         return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)

 class STFTLoss(torch.nn.Module):
     """STFT loss module."""

+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window):
         """Initialize STFT loss module."""
         super(STFTLoss, self).__init__()
         self.fft_size = fft_size
         self.shift_size = shift_size
         self.win_length = win_length
+        self.to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=24000, n_fft=fft_size, win_length=win_length, hop_length=shift_size, window_fn=window)

         self.spectral_convergenge_loss = SpectralConvergengeLoss()

         x_mag = self.to_mel(x)
         mean, std = -4, 4
         x_mag = (torch.log(1e-5 + x_mag) - mean) / std
+
         y_mag = self.to_mel(y)
         mean, std = -4, 4
         y_mag = (torch.log(1e-5 + y_mag) - mean) / std
+
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
         return sc_loss


 class MultiResolutionSTFTLoss(torch.nn.Module):
     """Multi resolution STFT loss module."""

+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window=torch.hann_window):
         """Initialize Multi resolution STFT loss module.
         Args:
             fft_sizes (list): List of FFT sizes.
         sc_loss /= len(self.stft_losses)

         return sc_loss
+
+
 def feature_loss(fmap_r, fmap_g):
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
             loss += torch.mean(torch.abs(rl - gl))

+    return loss*2


 def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     r_losses = []
     g_losses = []
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        r_loss = torch.mean((1-dr)**2)
         g_loss = torch.mean(dg**2)
+        loss += (r_loss + g_loss)
         r_losses.append(r_loss.item())
         g_losses.append(g_loss.item())

     loss = 0
     gen_losses = []
     for dg in disc_outputs:
+        l = torch.mean((1-dg)**2)
         gen_losses.append(l)
         loss += l

     return loss, gen_losses

 """ https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
 def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
     loss = 0
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
         tau = 0.04
+        m_DG = torch.median((dr-dg))
+        L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
         loss += tau - F.relu(tau - L_rel)
     return loss

 def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
     loss = 0
     for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
         tau = 0.04
+        m_DG = torch.median((dr-dg))
+        L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
         loss += tau - F.relu(tau - L_rel)
     return loss

 class GeneratorLoss(torch.nn.Module):
+
     def __init__(self, mpd, msd):
         super(GeneratorLoss, self).__init__()
         self.mpd = mpd
         self.msd = msd
+
     def forward(self, y, y_hat):
         y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
         y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
         loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
         loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

+        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
+
         loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
+
         return loss_gen_all.mean()
+
 class DiscriminatorLoss(torch.nn.Module):
+
     def __init__(self, mpd, msd):
         super(DiscriminatorLoss, self).__init__()
         self.mpd = mpd
         self.msd = msd
+
     def forward(self, y, y_hat):
         # MPD
         y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
+        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
         # MSD
         y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
+        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
+
+        loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)


         d_loss = loss_disc_s + loss_disc_f + loss_rel
+
         return d_loss.mean()
+
+
+
+
+# #####################
+# MIXED PRECISION
+
+
+class WhisperEncoderOnly(WhisperPreTrainedModel):
+    def __init__(self, config: WhisperConfig):
+        super().__init__(config)
+        self.encoder = WhisperEncoder(config)
+
+    def forward(self, input_features, attention_mask=None):
+        return self.encoder(input_features, attention_mask)
+


 class WavLMLoss(torch.nn.Module):
     def __init__(self, model, wd, model_sr, slm_sr=16000):
         super(WavLMLoss, self).__init__()
+
+        config = WhisperConfig.from_pretrained("Respair/Whisper_Large_v2_Encoder_Block")
+
+        # this will load the full model and keep only the encoder
+        full_model = WhisperEncoderOnly.from_pretrained("openai/whisper-large-v2", config=config, device_map='auto', torch_dtype=torch.bfloat16)
+        model = WhisperEncoderOnly(config)
+        model.encoder.load_state_dict(full_model.encoder.state_dict())
+        del full_model
+
+
+        self.wavlm = model.to(torch.bfloat16)
         self.wd = wd
         self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)

+
def forward(self, wav, y_rec, generator=False, discriminator=False, discriminator_forward=False):
|
231 |
+
|
232 |
+
if generator:
|
233 |
+
y_rec = y_rec.squeeze(1)
|
234 |
+
|
235 |
+
|
236 |
+
y_rec = whisper.pad_or_trim(y_rec)
|
237 |
+
y_rec = whisper.log_mel_spectrogram(y_rec)
|
238 |
+
|
239 |
+
with torch.no_grad():
|
240 |
+
y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
|
241 |
+
y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
|
242 |
+
y_df_hat_g = self.wd(y_rec_embeddings.to(torch.float32))
|
243 |
+
loss_gen = torch.mean((1-y_df_hat_g)**2)
|
244 |
+
|
245 |
+
return loss_gen.to(torch.float32)
|
246 |
+
|
247 |
+
elif discriminator:
|
248 |
+
|
249 |
+
wav = wav.squeeze(1)
|
250 |
+
y_rec = y_rec.squeeze(1)
|
251 |
+
|
252 |
+
wav = whisper.pad_or_trim(wav)
|
253 |
+
wav = whisper.log_mel_spectrogram(wav)
|
254 |
+
|
255 |
+
y_rec = whisper.pad_or_trim(y_rec)
|
256 |
+
y_rec = whisper.log_mel_spectrogram(y_rec)
|
257 |
+
|
258 |
+
with torch.no_grad():
|
259 |
+
wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
|
260 |
+
y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
|
261 |
+
|
262 |
+
y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
|
263 |
+
y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
|
264 |
+
|
265 |
+
y_d_rs = self.wd(y_embeddings.to(torch.float32))
|
266 |
+
y_d_gs = self.wd(y_rec_embeddings.to(torch.float32))
|
267 |
+
|
268 |
+
y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
|
269 |
+
|
270 |
+
r_loss = torch.mean((1-y_df_hat_r)**2)
|
271 |
+
g_loss = torch.mean((y_df_hat_g)**2)
|
272 |
+
|
273 |
+
loss_disc_f = r_loss + g_loss
|
274 |
+
|
275 |
+
return loss_disc_f.mean().to(torch.float32)
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
elif discriminator_forward:
|
280 |
+
# Squeeze the channel dimension if it's unnecessary
|
281 |
+
wav = wav.squeeze(1) # Adjust this line if the channel dimension is not at dim=1
|
282 |
+
|
283 |
+
|
284 |
+
with torch.no_grad():
|
285 |
+
|
286 |
+
wav_16 = self.resample(wav)
|
287 |
+
wav_16 = whisper.pad_or_trim(wav_16)
|
288 |
+
wav_16 = whisper.log_mel_spectrogram(wav_16)
|
289 |
+
|
290 |
+
wav_embeddings = self.wavlm.encoder(wav_16.to(torch.bfloat16) , output_hidden_states=True).hidden_states
|
291 |
+
y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
|
292 |
+
|
293 |
+
y_d_rs = self.wd(y_embeddings.to(torch.float32))
|
294 |
+
|
295 |
+
return y_d_rs
|
296 |
+
|
297 |
+
else:
|
298 |
+
|
299 |
+
wav = wav.squeeze(1)
|
300 |
+
y_rec = y_rec.squeeze(1)
|
301 |
+
|
302 |
+
wav = whisper.pad_or_trim(wav)
|
303 |
+
wav = whisper.log_mel_spectrogram(wav)
|
304 |
+
|
305 |
+
y_rec = whisper.pad_or_trim(y_rec)
|
306 |
+
y_rec = whisper.log_mel_spectrogram(y_rec)
|
307 |
+
|
308 |
+
with torch.no_grad():
|
309 |
+
wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
|
310 |
+
|
311 |
+
y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
|
312 |
+
|
313 |
+
|
314 |
+
floss = 0
|
315 |
+
for er, eg in zip([e.to(torch.float32) for e in wav_embeddings], [e.to(torch.float32) for e in y_rec_embeddings]):
|
316 |
+
floss += torch.mean(torch.abs(er - eg))
|
317 |
+
|
318 |
+
return floss.mean()
|



     def generator(self, y_rec):
+
+        y_rec = y_rec.squeeze(1)
+
+
+        y_rec = whisper.pad_or_trim(y_rec)
+        y_rec = whisper.log_mel_spectrogram(y_rec)

         with torch.no_grad():
+            y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+        y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+        y_df_hat_g = self.wd(y_rec_embeddings.to(torch.float32))
+        loss_gen = torch.mean((1-y_df_hat_g)**2)
+
+        return loss_gen.to(torch.float32)

+    def discriminator(self, wav, y_rec):
+
+        wav = wav.squeeze(1)
+        y_rec = y_rec.squeeze(1)
+
+        wav = whisper.pad_or_trim(wav)
+        wav = whisper.log_mel_spectrogram(wav)
+
+        y_rec = whisper.pad_or_trim(y_rec)
+        y_rec = whisper.log_mel_spectrogram(y_rec)
+
+        with torch.no_grad():
+            wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
+            y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states

+            y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+            y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

+            y_d_rs = self.wd(y_embeddings.to(torch.float32))
+            y_d_gs = self.wd(y_rec_embeddings.to(torch.float32))
+
+        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+        r_loss = torch.mean((1-y_df_hat_r)**2)
+        g_loss = torch.mean((y_df_hat_g)**2)
+
         loss_disc_f = r_loss + g_loss
+
+        return loss_disc_f.mean().to(torch.float32)
+
+


     def discriminator_forward(self, wav):
+        # Squeeze the channel dimension if it's unnecessary
+        wav = wav.squeeze(1)  # Adjust this line if the channel dimension is not at dim=1


+        with torch.no_grad():
+
+            wav_16 = self.resample(wav)
+            wav_16 = whisper.pad_or_trim(wav_16)
+            wav_16 = whisper.log_mel_spectrogram(wav_16)
+
+            wav_embeddings = self.wavlm.encoder(wav_16.to(torch.bfloat16), output_hidden_states=True).hidden_states
+            y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+            y_d_rs = self.wd(y_embeddings.to(torch.float32))
+
         return y_d_rs
+
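
For reference, a minimal usage sketch of the updated losses, not part of the commit: the `slm_disc` head and its channel count, the 24 kHz sample rate, and the dummy batch shapes are assumptions for illustration, and constructing `WavLMLoss` downloads the Whisper large-v2 encoder weights.

# Hypothetical usage sketch: any name below that is not defined in losses.py is an assumption.
import torch

from losses import MultiResolutionSTFTLoss, WavLMLoss

# Multi-resolution mel loss with the defaults defined in this file (24 kHz mel frontend).
stft_loss = MultiResolutionSTFTLoss()

# Stand-in SLM discriminator head: the loss flattens the stacked Whisper hidden states
# into channels before calling self.wd; assuming whisper-large-v2 dimensions that is
# 33 hidden states x 1280 dims.
slm_disc = torch.nn.Conv1d(33 * 1280, 1, kernel_size=1)

# `model` is unused by the new __init__ (it builds its own Whisper encoder);
# this call downloads the openai/whisper-large-v2 weights.
slm_loss = WavLMLoss(model=None, wd=slm_disc, model_sr=24000)

wav = torch.randn(2, 1, 24000)    # ground-truth audio, (batch, 1, samples), assumed shape
y_rec = torch.randn(2, 1, 24000)  # generator output, same shape

mel_loss = stft_loss(y_rec.squeeze(1), wav.squeeze(1))   # multi-resolution spectral convergence term
fm_loss = slm_loss(wav, y_rec)                           # default branch: Whisper feature matching
gen_adv = slm_loss(wav, y_rec, generator=True)           # adversarial term for the generator
disc_loss = slm_loss(wav, y_rec, discriminator=True)     # adversarial term for the SLM discriminator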