yangwang825 committed · Commit a6aed96 (verified) · 1 Parent(s): be28212

Upload model

Files changed (3)
  1. config.json +4 -2
  2. model.safetensors +3 -0
  3. modeling_hubert_spkreg.py +613 -0
config.json CHANGED
@@ -3,11 +3,12 @@
   "activation_dropout": 0.1,
   "apply_spec_augment": true,
   "architectures": [
-    "HubertModel"
+    "HubertSpkRegModel"
   ],
   "attention_dropout": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_hubert_spkreg.HubertSpkRegConfig"
+    "AutoConfig": "configuration_hubert_spkreg.HubertSpkRegConfig",
+    "AutoModel": "modeling_hubert_spkreg.HubertSpkRegModel"
   },
   "bos_token_id": 1,
   "classifier_proj_size": 256,
@@ -78,6 +79,7 @@
   "reduction": "mean",
   "scale": 30.0,
   "tokenizer_class": "Wav2Vec2CTCTokenizer",
+  "torch_dtype": "float32",
   "transformers_version": "4.46.2",
   "use_weighted_layer_sum": false,
   "vocab_size": 32
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:411e8e7f967ba2a68bc6fba072e6374effc390225c7fdb75b8731edd95717e15
+size 377510584
modeling_hubert_spkreg.py ADDED
@@ -0,0 +1,613 @@
import math
import warnings
from typing import Union, Tuple, Optional

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module
from transformers.models.hubert.modeling_hubert import (
    HubertFeatureEncoder,
    HubertFeatureProjection,
    HubertEncoderStableLayerNorm,
    HubertEncoder,
    _HIDDEN_STATES_START_POSITION
)

from .configuration_hubert_spkreg import HubertSpkRegConfig


class HubertSpkRegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HubertSpkRegConfig
    base_model_prefix = "hubert"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations make sure that all values before the output length indices are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

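# Editor's note: with HuBERT's default feature encoder (conv_kernel = (10, 3, 3, 3, 3, 2, 2),
# conv_stride = (5, 2, 2, 2, 2, 2, 2)), _get_feat_extract_output_lengths downsamples raw
# audio by roughly 320x; e.g. 16000 samples (1 s at 16 kHz):
#   16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49 frames.
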
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
            the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
            independently generated mask spans of length `mask_length` is computed by
            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
            actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
            each batch dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask

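# Editor's note: an illustrative call, assuming a batch of 2 sequences of 100 frames:
#   mask = _compute_mask_indices((2, 100), mask_prob=0.5, mask_length=10)
# returns a (2, 100) boolean numpy array in which up to ~50% of positions fall inside
# randomly placed 10-frame spans (overlapping spans make the realized coverage
# smaller than mask_prob).
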
class HubertSpkRegModel(HubertSpkRegPreTrainedModel):

    def __init__(self, config: HubertSpkRegConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = HubertFeatureProjection(config)

        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """

        Returns:

        Example:

        ```python
        >>> from transformers import AutoProcessor, HubertModel
        >>> from datasets import load_dataset
        >>> import soundfile as sf

        >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")


        >>> def map_to_array(batch):
        ...     speech, _ = sf.read(batch["file"])
        ...     batch["speech"] = speech
        ...     return batch


        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class AngularLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int):
        super(AngularLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.FloatTensor(out_features, in_features), requires_grad=True
        )
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        # Calculation of cos(theta)
        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
        return cosine

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )

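# Editor's note: AngularLinear outputs cosine similarities rather than affine logits --
# with the input and every class weight L2-normalized, F.linear yields cos(theta) in
# [-1, 1] for each (example, class) pair. The scale and margin are applied afterwards,
# inside the loss functions below.
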
class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """
    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor for logits (default: 30.0)
            margin: Additive cosine margin (default: 0.35)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method (default: "mean")
        """
        super(AMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        psi = cos_theta - self.margin
        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss

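# Editor's note: AM-Softmax penalizes only the target-class logit before softmax:
# logit_y = scale * (cos(theta_y) - margin), while non-target logits stay
# scale * cos(theta_j). With the defaults (scale=30, margin=0.35), even a perfectly
# aligned target (cos = 1.0) scores only 30 * 0.65 = 19.5.
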
class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
    """
    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor for logits (default: 30.0)
            margin: Angular margin (default: 0.35)
            easy_margin: Use the easy margin loss (default: False)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method (default: "mean")
        """
        super(AAMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        theta = torch.acos(cos_theta)
        psi = torch.cos(theta + self.margin)
        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss

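# Editor's note: AAM-Softmax applies the margin in angle space, cos(theta_y + margin),
# rather than CosFace's cos(theta_y) - margin; the effective penalty
# cos(theta) - cos(theta + m) = 2*sin(m/2)*sin(theta + m/2) grows with the angle,
# unlike CosFace's constant offset. Also note that `easy_margin` is stored by the
# constructor above but never used in this forward pass.
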
class HubertSpkRegForSequenceClassification(HubertSpkRegPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
            )
        self.hubert = HubertSpkRegModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)

        if self.config.loss_fct == 'cross_entropy':
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_angular_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        self.hubert.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.hubert.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

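        # Editor's note: the branch above is length-aware mean pooling -- padded frames
        # are zeroed out and each clip's sum is divided by its true frame count, so
        # shorter clips in a padded batch are not diluted by padding positions.
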
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.loss_fct == 'cross_entropy':
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_margin':
                loss_fct = AMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_angular_margin':
                loss_fct = AAMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    easy_margin=self.config.easy_margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
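
For reference, a minimal sketch of the classification head in action. It builds a small randomly initialized model directly from the companion config class; configuration_hubert_spkreg.py is referenced by this file but not part of this diff, so the accepted fields below are assumptions inferred from how the modeling code reads `config`, and the imports assume the two files are importable as a package (as they are when loaded from the Hub with trust_remote_code):

```python
import torch
from .configuration_hubert_spkreg import HubertSpkRegConfig  # companion module in this repo
from .modeling_hubert_spkreg import HubertSpkRegForSequenceClassification

# Assumed config fields, matching what the forward pass reads:
# loss_fct, scale, margin, easy_margin, label_smoothing, reduction.
config = HubertSpkRegConfig(
    num_labels=10,
    loss_fct="additive_angular_margin",
    scale=30.0,
    margin=0.35,
    easy_margin=False,
    label_smoothing=0.0,
    reduction="mean",
)
model = HubertSpkRegForSequenceClassification(config).eval()

wav = torch.randn(2, 16000)     # two 1-second clips at 16 kHz
labels = torch.tensor([3, 7])
with torch.no_grad():
    out = model(wav, labels=labels)

print(out.logits.shape)  # torch.Size([2, 10]) -- cosine logits in [-1, 1]
print(out.loss)          # AAM-softmax (ArcFace) loss
```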