Update bert_layers.py
Browse files- bert_layers.py +2 -1
bert_layers.py
CHANGED
@@ -903,6 +903,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
903 |
|
904 |
# JAANDOUI:
|
905 |
all_attention_weights = outputs[2]
|
|
|
906 |
# print(f'last: {all_attention_weights}')
|
907 |
|
908 |
pooled_output = self.dropout(pooled_output)
|
@@ -947,6 +948,6 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
947 |
logits=logits,
|
948 |
hidden_states=outputs[0],
|
949 |
#JAANDOUI: returning all_attention_weights here
|
950 |
-
attentions=outputs[2],
|
951 |
)
|
952 |
|
|
|
903 |
|
904 |
# JAANDOUI:
|
905 |
all_attention_weights = outputs[2]
|
906 |
+
|
907 |
# print(f'last: {all_attention_weights}')
|
908 |
|
909 |
pooled_output = self.dropout(pooled_output)
|
|
|
948 |
logits=logits,
|
949 |
hidden_states=outputs[0],
|
950 |
#JAANDOUI: returning all_attention_weights here
|
951 |
+
attentions=torch.stack(outputs[2], dim=0),
|
952 |
)
|
953 |
|