Respair committed (verified)
Commit 5675557 · 1 Parent(s): 0e5d53e

Update losses.py

Files changed (1): losses.py (+207 -122)
losses.py CHANGED
@@ -2,7 +2,10 @@ import torch
  from torch import nn
  import torch.nn.functional as F
  import torchaudio
- from transformers import AutoModel
 
 
  class SpectralConvergengeLoss(torch.nn.Module):
@@ -22,25 +25,16 @@ class SpectralConvergengeLoss(torch.nn.Module):
          """
          return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
 
-
  class STFTLoss(torch.nn.Module):
      """STFT loss module."""
 
-     def __init__(
-         self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
-     ):
          """Initialize STFT loss module."""
          super(STFTLoss, self).__init__()
          self.fft_size = fft_size
          self.shift_size = shift_size
          self.win_length = win_length
-         self.to_mel = torchaudio.transforms.MelSpectrogram(
-             sample_rate=24000,
-             n_fft=fft_size,
-             win_length=win_length,
-             hop_length=shift_size,
-             window_fn=window,
-         )
 
          self.spectral_convergenge_loss = SpectralConvergengeLoss()
 
@@ -56,25 +50,23 @@ class STFTLoss(torch.nn.Module):
          x_mag = self.to_mel(x)
          mean, std = -4, 4
          x_mag = (torch.log(1e-5 + x_mag) - mean) / std
-
          y_mag = self.to_mel(y)
          mean, std = -4, 4
          y_mag = (torch.log(1e-5 + y_mag) - mean) / std
-
-         sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
          return sc_loss
 
 
  class MultiResolutionSTFTLoss(torch.nn.Module):
      """Multi resolution STFT loss module."""
 
-     def __init__(
-         self,
-         fft_sizes=[1024, 2048, 512],
-         hop_sizes=[120, 240, 50],
-         win_lengths=[600, 1200, 240],
-         window=torch.hann_window,
-     ):
          """Initialize Multi resolution STFT loss module.
          Args:
              fft_sizes (list): List of FFT sizes.
@@ -104,15 +96,15 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
          sc_loss /= len(self.stft_losses)
 
          return sc_loss
-
-
  def feature_loss(fmap_r, fmap_g):
      loss = 0
      for dr, dg in zip(fmap_r, fmap_g):
          for rl, gl in zip(dr, dg):
              loss += torch.mean(torch.abs(rl - gl))
 
-     return loss * 2
 
 
  def discriminator_loss(disc_real_outputs, disc_generated_outputs):
@@ -120,9 +112,9 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
      r_losses = []
      g_losses = []
      for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-         r_loss = torch.mean((1 - dr) ** 2)
          g_loss = torch.mean(dg**2)
-         loss += r_loss + g_loss
          r_losses.append(r_loss.item())
          g_losses.append(g_loss.item())
 
@@ -133,42 +125,38 @@ def generator_loss(disc_outputs):
      loss = 0
      gen_losses = []
      for dg in disc_outputs:
-         l = torch.mean((1 - dg) ** 2)
          gen_losses.append(l)
          loss += l
 
      return loss, gen_losses
 
-
  """ https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
-
-
  def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
      loss = 0
      for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
          tau = 0.04
-         m_DG = torch.median((dr - dg))
-         L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
          loss += tau - F.relu(tau - L_rel)
      return loss
 
-
  def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
      loss = 0
      for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
          tau = 0.04
-         m_DG = torch.median((dr - dg))
-         L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
          loss += tau - F.relu(tau - L_rel)
      return loss
 
-
  class GeneratorLoss(torch.nn.Module):
 
      def __init__(self, mpd, msd):
          super(GeneratorLoss, self).__init__()
          self.mpd = mpd
          self.msd = msd
-
      def forward(self, y, y_hat):
          y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
          y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
@@ -177,127 +165,224 @@ class GeneratorLoss(torch.nn.Module):
          loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
          loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
 
-         loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(
-             y_ds_hat_r, y_ds_hat_g
-         )
-
          loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
-
          return loss_gen_all.mean()
-
-
  class DiscriminatorLoss(torch.nn.Module):
 
      def __init__(self, mpd, msd):
          super(DiscriminatorLoss, self).__init__()
          self.mpd = mpd
          self.msd = msd
-
      def forward(self, y, y_hat):
          # MPD
          y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
-         loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
-             y_df_hat_r, y_df_hat_g
-         )
          # MSD
          y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
-         loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
-             y_ds_hat_r, y_ds_hat_g
-         )
 
-         loss_rel = discriminator_TPRLS_loss(
-             y_df_hat_r, y_df_hat_g
-         ) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
 
          d_loss = loss_disc_s + loss_disc_f + loss_rel
-
          return d_loss.mean()
 
 
  class WavLMLoss(torch.nn.Module):
      def __init__(self, model, wd, model_sr, slm_sr=16000):
          super(WavLMLoss, self).__init__()
-         self.wavlm = AutoModel.from_pretrained(model)
          self.wd = wd
          self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
 
-     def forward(self, wav, y_rec):
-         with torch.no_grad():
-             wav_16 = self.resample(wav)
-             wav_embeddings = self.wavlm(
-                 input_values=wav_16, output_hidden_states=True
-             ).hidden_states
-             y_rec_16 = self.resample(y_rec)
-             y_rec_embeddings = self.wavlm(
-                 input_values=y_rec_16.squeeze(), output_hidden_states=True
-             ).hidden_states
 
-         floss = 0
-         for er, eg in zip(wav_embeddings, y_rec_embeddings):
-             floss += torch.mean(torch.abs(er - eg))
 
-         return floss.mean()
 
      def generator(self, y_rec):
-         y_rec_16 = self.resample(y_rec)
-         y_rec_embeddings = self.wavlm(
-             input_values=y_rec_16, output_hidden_states=True
-         ).hidden_states
-         y_rec_embeddings = (
-             torch.stack(y_rec_embeddings, dim=1)
-             .transpose(-1, -2)
-             .flatten(start_dim=1, end_dim=2)
-         )
-         y_df_hat_g = self.wd(y_rec_embeddings)
-         loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
-
-         return loss_gen
 
-     def discriminator(self, wav, y_rec):
          with torch.no_grad():
-             wav_16 = self.resample(wav)
-             wav_embeddings = self.wavlm(
-                 input_values=wav_16, output_hidden_states=True
-             ).hidden_states
-             y_rec_16 = self.resample(y_rec)
-             y_rec_embeddings = self.wavlm(
-                 input_values=y_rec_16, output_hidden_states=True
-             ).hidden_states
-
-             y_embeddings = (
-                 torch.stack(wav_embeddings, dim=1)
-                 .transpose(-1, -2)
-                 .flatten(start_dim=1, end_dim=2)
-             )
-             y_rec_embeddings = (
-                 torch.stack(y_rec_embeddings, dim=1)
-                 .transpose(-1, -2)
-                 .flatten(start_dim=1, end_dim=2)
-             )
-
-             y_d_rs = self.wd(y_embeddings)
-             y_d_gs = self.wd(y_rec_embeddings)
 
-         y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
 
-         r_loss = torch.mean((1 - y_df_hat_r) ** 2)
-         g_loss = torch.mean((y_df_hat_g) ** 2)
 
          loss_disc_f = r_loss + g_loss
 
-         return loss_disc_f.mean()
 
      def discriminator_forward(self, wav):
-         with torch.no_grad():
-             wav_16 = self.resample(wav)
-             wav_embeddings = self.wavlm(
-                 input_values=wav_16, output_hidden_states=True
-             ).hidden_states
-             y_embeddings = (
-                 torch.stack(wav_embeddings, dim=1)
-                 .transpose(-1, -2)
-                 .flatten(start_dim=1, end_dim=2)
-             )
 
-         y_d_rs = self.wd(y_embeddings)
 
          return y_d_rs
 
 
  from torch import nn
  import torch.nn.functional as F
  import torchaudio
+ from transformers import AutoModel, WhisperConfig, WhisperPreTrainedModel
+ import whisper
+
+ from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
 
  class SpectralConvergengeLoss(torch.nn.Module):
          """
          return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
 
  class STFTLoss(torch.nn.Module):
      """STFT loss module."""
 
+     def __init__(self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window):
          """Initialize STFT loss module."""
          super(STFTLoss, self).__init__()
          self.fft_size = fft_size
          self.shift_size = shift_size
          self.win_length = win_length
+         self.to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=24000, n_fft=fft_size, win_length=win_length, hop_length=shift_size, window_fn=window)
 
          self.spectral_convergenge_loss = SpectralConvergengeLoss()
 
          x_mag = self.to_mel(x)
          mean, std = -4, 4
          x_mag = (torch.log(1e-5 + x_mag) - mean) / std
+
          y_mag = self.to_mel(y)
          mean, std = -4, 4
          y_mag = (torch.log(1e-5 + y_mag) - mean) / std
+
+         sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
          return sc_loss
 
 
  class MultiResolutionSTFTLoss(torch.nn.Module):
      """Multi resolution STFT loss module."""
 
+     def __init__(self,
+                  fft_sizes=[1024, 2048, 512],
+                  hop_sizes=[120, 240, 50],
+                  win_lengths=[600, 1200, 240],
+                  window=torch.hann_window):
          """Initialize Multi resolution STFT loss module.
          Args:
              fft_sizes (list): List of FFT sizes.
 
          sc_loss /= len(self.stft_losses)
 
          return sc_loss
+
+
  def feature_loss(fmap_r, fmap_g):
      loss = 0
      for dr, dg in zip(fmap_r, fmap_g):
          for rl, gl in zip(dr, dg):
              loss += torch.mean(torch.abs(rl - gl))
 
+     return loss*2
 
 
  def discriminator_loss(disc_real_outputs, disc_generated_outputs):
      r_losses = []
      g_losses = []
      for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1-dr)**2)
          g_loss = torch.mean(dg**2)
+         loss += (r_loss + g_loss)
          r_losses.append(r_loss.item())
          g_losses.append(g_loss.item())
 
      loss = 0
      gen_losses = []
      for dg in disc_outputs:
+         l = torch.mean((1-dg)**2)
          gen_losses.append(l)
          loss += l
 
      return loss, gen_losses
 
  """ https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
 
  def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
      loss = 0
      for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
          tau = 0.04
+         m_DG = torch.median((dr-dg))
+         L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
          loss += tau - F.relu(tau - L_rel)
      return loss
 
  def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
      loss = 0
      for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
          tau = 0.04
+         m_DG = torch.median((dr-dg))
+         L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
          loss += tau - F.relu(tau - L_rel)
      return loss
 
  class GeneratorLoss(torch.nn.Module):
+
      def __init__(self, mpd, msd):
          super(GeneratorLoss, self).__init__()
          self.mpd = mpd
          self.msd = msd
+
      def forward(self, y, y_hat):
          y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
          y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
 
          loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
          loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
 
+         loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
+
          loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
+
          return loss_gen_all.mean()
+
  class DiscriminatorLoss(torch.nn.Module):
+
      def __init__(self, mpd, msd):
          super(DiscriminatorLoss, self).__init__()
          self.mpd = mpd
          self.msd = msd
+
      def forward(self, y, y_hat):
          # MPD
          y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
+         loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
          # MSD
          y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
+         loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
+
+         loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
 
          d_loss = loss_disc_s + loss_disc_f + loss_rel
+
          return d_loss.mean()
+
+
+ # #####################
+ # MIXED PRECISION
+
+
+ class WhisperEncoderOnly(WhisperPreTrainedModel):
+     def __init__(self, config: WhisperConfig):
+         super().__init__(config)
+         self.encoder = WhisperEncoder(config)
+
+     def forward(self, input_features, attention_mask=None):
+         return self.encoder(input_features, attention_mask)
+
 
  class WavLMLoss(torch.nn.Module):
      def __init__(self, model, wd, model_sr, slm_sr=16000):
          super(WavLMLoss, self).__init__()
+
+         config = WhisperConfig.from_pretrained("Respair/Whisper_Large_v2_Encoder_Block")
+
+         # this will load the full model and keep only the encoder
+         full_model = WhisperEncoderOnly.from_pretrained("openai/whisper-large-v2", config=config, device_map='auto', torch_dtype=torch.bfloat16)
+         model = WhisperEncoderOnly(config)
+         model.encoder.load_state_dict(full_model.encoder.state_dict())
+         del full_model
+
+         self.wavlm = model.to(torch.bfloat16)
          self.wd = wd
          self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
 
+     def forward(self, wav, y_rec, generator=False, discriminator=False, discriminator_forward=False):
+
+         if generator:
+             y_rec = y_rec.squeeze(1)
+
+             y_rec = whisper.pad_or_trim(y_rec)
+             y_rec = whisper.log_mel_spectrogram(y_rec)
+
+             with torch.no_grad():
+                 y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+                 y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+             y_df_hat_g = self.wd(y_rec_embeddings.to(torch.float32))
+             loss_gen = torch.mean((1-y_df_hat_g)**2)
+
+             return loss_gen.to(torch.float32)
+
+         elif discriminator:
+
+             wav = wav.squeeze(1)
+             y_rec = y_rec.squeeze(1)
+
+             wav = whisper.pad_or_trim(wav)
+             wav = whisper.log_mel_spectrogram(wav)
+
+             y_rec = whisper.pad_or_trim(y_rec)
+             y_rec = whisper.log_mel_spectrogram(y_rec)
+
+             with torch.no_grad():
+                 wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
+                 y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+
+                 y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+                 y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+                 y_d_rs = self.wd(y_embeddings.to(torch.float32))
+                 y_d_gs = self.wd(y_rec_embeddings.to(torch.float32))
+
+             y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+             r_loss = torch.mean((1-y_df_hat_r)**2)
+             g_loss = torch.mean((y_df_hat_g)**2)
+
+             loss_disc_f = r_loss + g_loss
+
+             return loss_disc_f.mean().to(torch.float32)
+
+         elif discriminator_forward:
+             # Squeeze the channel dimension if it's unnecessary
+             wav = wav.squeeze(1)  # Adjust this line if the channel dimension is not at dim=1
+
+             with torch.no_grad():
+
+                 wav_16 = self.resample(wav)
+                 wav_16 = whisper.pad_or_trim(wav_16)
+                 wav_16 = whisper.log_mel_spectrogram(wav_16)
+
+                 wav_embeddings = self.wavlm.encoder(wav_16.to(torch.bfloat16), output_hidden_states=True).hidden_states
+                 y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+             y_d_rs = self.wd(y_embeddings.to(torch.float32))
+
+             return y_d_rs
+
+         else:
+
+             wav = wav.squeeze(1)
+             y_rec = y_rec.squeeze(1)
+
+             wav = whisper.pad_or_trim(wav)
+             wav = whisper.log_mel_spectrogram(wav)
+
+             y_rec = whisper.pad_or_trim(y_rec)
+             y_rec = whisper.log_mel_spectrogram(y_rec)
+
+             with torch.no_grad():
+                 wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
+                 y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+
+             floss = 0
+             for er, eg in zip([e.to(torch.float32) for e in wav_embeddings], [e.to(torch.float32) for e in y_rec_embeddings]):
+                 floss += torch.mean(torch.abs(er - eg))
+
+             return floss.mean()
 
      def generator(self, y_rec):
+
+         y_rec = y_rec.squeeze(1)
+
+         y_rec = whisper.pad_or_trim(y_rec)
+         y_rec = whisper.log_mel_spectrogram(y_rec)
 
          with torch.no_grad():
+             y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
+             y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+         y_df_hat_g = self.wd(y_rec_embeddings.to(torch.float32))
+         loss_gen = torch.mean((1-y_df_hat_g)**2)
+
+         return loss_gen.to(torch.float32)
 
+     def discriminator(self, wav, y_rec):
+
+         wav = wav.squeeze(1)
+         y_rec = y_rec.squeeze(1)
+
+         wav = whisper.pad_or_trim(wav)
+         wav = whisper.log_mel_spectrogram(wav)
+
+         y_rec = whisper.pad_or_trim(y_rec)
+         y_rec = whisper.log_mel_spectrogram(y_rec)
+
+         with torch.no_grad():
+             wav_embeddings = self.wavlm.encoder(wav.to(torch.bfloat16), output_hidden_states=True).hidden_states
+             y_rec_embeddings = self.wavlm.encoder(y_rec.to(torch.bfloat16), output_hidden_states=True).hidden_states
 
+             y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+             y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
 
+             y_d_rs = self.wd(y_embeddings.to(torch.float32))
+             y_d_gs = self.wd(y_rec_embeddings.to(torch.float32))
+
+         y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+         r_loss = torch.mean((1-y_df_hat_r)**2)
+         g_loss = torch.mean((y_df_hat_g)**2)
+
          loss_disc_f = r_loss + g_loss
+
+         return loss_disc_f.mean().to(torch.float32)
+
 
      def discriminator_forward(self, wav):
+         # Squeeze the channel dimension if it's unnecessary
+         wav = wav.squeeze(1)  # Adjust this line if the channel dimension is not at dim=1
 
+         with torch.no_grad():
+
+             wav_16 = self.resample(wav)
+             wav_16 = whisper.pad_or_trim(wav_16)
+             wav_16 = whisper.log_mel_spectrogram(wav_16)
+
+             wav_embeddings = self.wavlm.encoder(wav_16.to(torch.bfloat16), output_hidden_states=True).hidden_states
+             y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
+
+         y_d_rs = self.wd(y_embeddings.to(torch.float32))
+
          return y_d_rs
+
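For orientation, a minimal, hypothetical usage sketch of the updated WavLMLoss follows (not part of the commit). It assumes the new losses.py is importable, 24 kHz audio shaped (batch, 1, samples), that `whisper` is the openai-whisper package, and a toy Conv1d standing in for the SLM discriminator `wd`; the 42240-channel figure assumes the whisper-large-v2 encoder's 33 hidden states of width 1280. The `model` argument appears unused by the new Whisper-based __init__, so None is passed here purely for illustration.

import torch
from torch import nn

from losses import WavLMLoss

# Toy stand-in for the SLM discriminator head `wd`; the real one comes from the training setup.
# 42240 = 33 hidden states x 1280 dims, assumed for the whisper-large-v2 encoder.
wd = nn.Conv1d(42240, 1, kernel_size=3, padding=1)

# `model` is ignored by the new Whisper-based loader, which always builds WhisperEncoderOnly.
wlm = WavLMLoss(model=None, wd=wd, model_sr=24000)

wav = torch.randn(2, 1, 24000)    # real audio, assumed layout (B, 1, T) at 24 kHz
y_rec = torch.randn(2, 1, 24000)  # reconstructed audio from the generator

loss_fm = wlm(wav, y_rec)                               # default branch: feature-matching loss over encoder states
loss_gen = wlm(wav, y_rec, generator=True)              # LSGAN-style generator loss from `wd`
loss_d = wlm(wav, y_rec, discriminator=True)            # LSGAN-style discriminator loss for `wd`
scores = wlm(wav, y_rec, discriminator_forward=True)    # raw `wd` scores on the real audio only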