jaandoui committed (verified)
Commit 5532eb0 · 1 Parent(s): 4af8bbb

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +30 -23
bert_layers.py CHANGED
@@ -203,7 +203,7 @@ class BertUnpadSelfAttention(nn.Module):
            print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
        except:
            print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
-       return rearrange(attention, 'nnz h d -> nnz (h d)')
+       return rearrange(attention, 'nnz h d -> nnz (h d)'), attention_probs


# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
@@ -254,7 +254,7 @@ class BertUnpadAttention(nn.Module):
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
-       self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
+       self_output, attention_probs = self.self(input_tensor, cu_seqlens, max_s, indices,
                                attn_mask, bias)

        try:
@@ -266,7 +266,7 @@ class BertUnpadAttention(nn.Module):
            return self.output(index_first_axis(self_output, subset_idx),
                               index_first_axis(input_tensor, subset_idx))
        else:
-           return self.output(self_output, input_tensor)
+           return self.output(self_output, input_tensor), attention_probs


class BertGatedLinearUnitMLP(nn.Module):
@@ -347,12 +347,12 @@ class BertLayer(nn.Module):
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
-       attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
+       attention_output, attention_probs = self.attention(hidden_states, cu_seqlens, seqlen,
                                          subset_idx, indices, attn_mask, bias)
        print(f'BertLayer attention_output shape: {attention_output.shape}')
        layer_output = self.mlp(attention_output)
        print(f'BertLayer layer_output shape: {layer_output.shape}')
-       return layer_output, attention_output # JAANDOUI: this only returns layer_output in the original work.
+       return layer_output, attention_output, attention_probs # JAANDOUI: this only returns layer_output in the original work.


class BertEncoder(nn.Module):
@@ -467,11 +467,12 @@ class BertEncoder(nn.Module):

        all_encoder_layers = []
        all_attention_weights = [] # List to store attention weights
+       all_attention_probs = []

        if subset_mask is None:
            for layer_module in self.layer:
                # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
-               hidden_states, attention_weights = layer_module(hidden_states,
+               hidden_states, attention_weights, attention_probs = layer_module(hidden_states,
                                                cu_seqlens,
                                                seqlen,
                                                None,
@@ -482,6 +483,8 @@ class BertEncoder(nn.Module):
                # print(f'Inner Attention: {attention_weights}')
                print(f'Inner Attention shape: {attention_weights.shape}')
                all_attention_weights.append(attention_weights) # Store attention weights
+               all_attention_probs.append(attention_probs) # Store attention probs
+
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            # Pad inputs and mask. It will insert back zero-padded tokens.
@@ -494,7 +497,7 @@ class BertEncoder(nn.Module):
            for i in range(len(self.layer) - 1):
                layer_module = self.layer[i]
                # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
-               hidden_states, attention_weights = layer_module(hidden_states,
+               hidden_states, attention_weights, attention_probs = layer_module(hidden_states,
                                                cu_seqlens,
                                                seqlen,
                                                None,
@@ -502,12 +505,14 @@ class BertEncoder(nn.Module):
                                                attn_mask=attention_mask,
                                                bias=alibi_attn_mask)
                all_attention_weights.append(attention_weights) # JAANDOUI: Store attention weights
+               all_attention_probs.append(attention_probs) # Store attention probs
+
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                       as_tuple=False).flatten()
            # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
-           hidden_states, attention_weights = self.layer[-1](hidden_states,
+           hidden_states, attention_weights, attention_probs = self.layer[-1](hidden_states,
                                              cu_seqlens,
                                              seqlen,
                                              subset_idx=subset_idx,
@@ -515,6 +520,8 @@ class BertEncoder(nn.Module):
                                              attn_mask=attention_mask,
                                              bias=alibi_attn_mask)
            all_attention_weights.append(attention_weights) # JAANDOUI: appending the attention of different layers together.
+           all_attention_probs.append(attention_probs) # Store attention probs
+
            # print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
            print(f'and this is the [0]shape inside encoder: \n {all_attention_weights[0].shape}')
            # print(f'NUMBER6: {all_attention_weights}')
@@ -522,7 +529,7 @@ class BertEncoder(nn.Module):
            all_encoder_layers.append(hidden_states)

        # JAANDOUI: Since we now return both, we need to handle them wherever BertEncoder forward is called.
-       return all_encoder_layers, all_attention_weights # Return both hidden states and attention weights
+       return all_encoder_layers, all_attention_weights, all_attention_probs # Return both hidden states and attention weights
        # return all_encoder_layers # JAANDOUI: original return.


@@ -649,7 +656,7 @@ class BertModel(BertPreTrainedModel):

        # JAANDOUI: first part where we call self.encoder (which is the instance of BertEncoder defined here)
        # JAANDOUI: need to return the attention weights here too.
-       encoder_outputs, all_attention_weights = self.encoder(
+       encoder_outputs, all_attention_weights, all_attention_probs = self.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
@@ -681,11 +688,11 @@ class BertModel(BertPreTrainedModel):
        # JAANDOUI: returning all_attention_weights too
        if self.pooler is not None:
            # print(f'NUMBER8: {all_attention_weights}')
-           return encoder_outputs, pooled_output, all_attention_weights
+           return encoder_outputs, pooled_output, all_attention_weights, all_attention_probs

        # JAANDOUI: returning all_attention_weights too
        # print(f'NUMBER9: {all_attention_weights}')
-       return encoder_outputs, None, all_attention_weights
+       return encoder_outputs, None, all_attention_weights, all_attention_probs
        # JAANDOUI: need to handle the returned elements wherever BertModel is instantiated.

    ###################
@@ -924,27 +931,27 @@ class BertForSequenceClassification(BertPreTrainedModel):
        pooled_output = outputs[1]

        try:
-           print(f'outputs[2] before reassignment SHAPE: {outputs[2][0].shape} ')
+           print(f'outputs[2] before reassignment SHAPE: {outputs[3][0].shape} ')
        except:
-           print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[2][0])} '))
+           print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[3][0])} '))

        # JAANDOUI:
-       all_attention_weights = outputs[2]
+       all_attention_probs = outputs[3]

        try:
-           print(f'outputs[2] AFTER reassignment SHAPE: {outputs[2][0].shape} ')
+           print(f'outputs[2] AFTER reassignment SHAPE: {outputs[3][0].shape} ')
        except:
-           print(print(f'outputs[2] AFTER reassignment LENGTH: {len(outputs[2][0])} '))
+           print(print(f'outputs[2] AFTER reassignment LENGTH: {len(outputs[3][0])} '))



        try:
-           print(f'all_attention_weights last: {all_attention_weights.shape}')
+           print(f'all_attention_weights last: {all_attention_probs.shape}')
        except:
            try:
-               print(f'last first except: {all_attention_weights[0].shape}')
+               print(f'last first except: {all_attention_probs[0].shape}')
            except:
-               print(f'last second except: {len(all_attention_weights[0])}')
+               print(f'last second except: {len(all_attention_probs[0])}')


        pooled_output = self.dropout(pooled_output)
@@ -984,9 +991,9 @@ class BertForSequenceClassification(BertPreTrainedModel):

        # print(outputs.attentions)
        try:
-           print(f'not stacked final attention SHAPE: {outputs[2][0].shape}')
+           print(f'not stacked final attention SHAPE: {outputs[3][0].shape}')
        except:
-           print(f'not stacked final attention LEN: {len(outputs[2])}')
+           print(f'not stacked final attention LEN: {len(outputs[3])}')

        # try:
        # print(f'STACKED final attention SHAPE: {(outputs.attentions).shape}')
@@ -1002,6 +1009,6 @@ class BertForSequenceClassification(BertPreTrainedModel):
            hidden_states=outputs[0],
            #JAANDOUI: returning all_attention_weights here
            # attentions=torch.stack(outputs[2], dim=0),
-           attentions=outputs[2], # JAANDOUI TODO: should I stack here ????
+           attentions=outputs[3], # JAANDOUI TODO: should I stack here ????
        )
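After this change, BertForSequenceClassification exposes the per-layer attention tensors collected in BertEncoder through the attentions field of its output (outputs[3] internally). Below is a minimal, hypothetical usage sketch, not part of the commit: the repository id and the input sequence are placeholders, it assumes the checkpoint is loaded with trust_remote_code=True so that this bert_layers.py is the code that actually runs, and it assumes the repo maps its classes to AutoModelForSequenceClassification. It also sketches one answer to the "should I stack here ????" TODO: torch.stack is only valid when every layer's tensor has the same shape, otherwise the list should be kept as-is.

# Hypothetical caller of the modified model (placeholder names, not from this commit).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "your-namespace/your-fork"  # placeholder: the repo that ships this bert_layers.py
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("ACGTAGCATCGGATCTATCTATCGACACTTGG", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# `attentions` carries the list built in BertEncoder: one tensor per layer.
per_layer = list(outputs.attentions)
print(len(per_layer), per_layer[0].shape)

# Stacking (the open TODO above) is only safe when all layers share one shape;
# otherwise keep the Python list (e.g. if the unpadded path returns nnz-shaped tensors).
if all(t.shape == per_layer[0].shape for t in per_layer):
    stacked = torch.stack(per_layer, dim=0)  # (num_layers, *per_layer_shape)
else:
    stacked = None

Whether stacking belongs inside the model or in the caller is exactly what the remaining debug prints and the TODO are probing; returning the raw list, as the commit does, leaves that choice to the caller.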