jaandoui commited on
Commit
0f58b02
·
verified ·
1 Parent(s): 40c6375

Update bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +2 -1
bert_layers.py CHANGED
@@ -903,6 +903,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
903
 
904
  # JAANDOUI:
905
  all_attention_weights = outputs[2]
 
906
  # print(f'last: {all_attention_weights}')
907
 
908
  pooled_output = self.dropout(pooled_output)
@@ -947,6 +948,6 @@ class BertForSequenceClassification(BertPreTrainedModel):
947
  logits=logits,
948
  hidden_states=outputs[0],
949
  #JAANDOUI: returning all_attention_weights here
950
- attentions=outputs[2],
951
  )
952
 
 
903
 
904
  # JAANDOUI:
905
  all_attention_weights = outputs[2]
906
+
907
  # print(f'last: {all_attention_weights}')
908
 
909
  pooled_output = self.dropout(pooled_output)
 
948
  logits=logits,
949
  hidden_states=outputs[0],
950
  #JAANDOUI: returning all_attention_weights here
951
+ attentions=torch.stack(outputs[2], dim=0),
952
  )
953