Update bert_layers.py
bert_layers.py (+30 -23)
@@ -203,7 +203,7 @@ class BertUnpadSelfAttention(nn.Module):
             print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
         except:
             print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
-        return rearrange(attention, 'nnz h d -> nnz (h d)')
+        return rearrange(attention, 'nnz h d -> nnz (h d)'), attention_probs


 # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
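For context on where the newly returned attention_probs would come from: this hunk only changes the return statement, so the sketch below is an assumption. It shows the standard softmax-attention fallback path (the names q, k, v, bias, p_dropout are illustrative, not this file's identifiers); if a FlashAttention kernel path is taken instead, the full probability matrix is normally not materialized at all.

import math
import torch
import torch.nn.functional as F
from einops import rearrange

def unpadded_self_attention_sketch(q, k, v, bias, p_dropout=0.0):
    # q, k, v: (batch, heads, seqlen, head_dim); bias: additive mask / ALiBi bias
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    attention_probs = F.softmax(scores + bias, dim=-1)         # (b, h, s, s)
    attention_probs = F.dropout(attention_probs, p_dropout)
    attention = torch.matmul(attention_probs, v)               # (b, h, s, d)
    attention = rearrange(attention, 'b h s d -> (b s) h d')   # nnz-style layout
    # Mirrors the changed return: flattened heads plus the probability matrix.
    return rearrange(attention, 'nnz h d -> nnz (h d)'), attention_probs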
@@ -254,7 +254,7 @@ class BertUnpadAttention(nn.Module):
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
         """
-        self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
+        self_output, attention_probs = self.self(input_tensor, cu_seqlens, max_s, indices,
                                 attn_mask, bias)

         try:
@@ -266,7 +266,7 @@ class BertUnpadAttention(nn.Module):
             return self.output(index_first_axis(self_output, subset_idx),
                                index_first_axis(input_tensor, subset_idx))
         else:
-            return self.output(self_output, input_tensor)
+            return self.output(self_output, input_tensor), attention_probs


 class BertGatedLinearUnitMLP(nn.Module):
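One detail worth noting in this hunk: only the else branch now returns a (output, attention_probs) tuple, while the subset_idx branch at lines 266-267 still returns a bare tensor, and BertLayer below unpacks two values unconditionally. A small self-contained illustration of why that distinction matters (purely illustrative, not the commit's code):

import torch

def unpack_like_bertlayer(result):
    # Two-element unpack, as BertLayer.forward now does with self.attention(...).
    attention_output, attention_probs = result
    return attention_output, attention_probs

# Tuple return (else branch): unpacks as intended.
unpack_like_bertlayer((torch.zeros(4, 8), torch.zeros(1, 4, 4)))

# Bare-tensor return (subset_idx branch): unpacking iterates the tensor along
# dim 0 instead, so it either mis-unpacks (first dim == 2) or raises.
try:
    unpack_like_bertlayer(torch.zeros(4, 8))
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)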
@@ -347,12 +347,12 @@ class BertLayer(nn.Module):
             attn_mask: None or (batch, max_seqlen_in_batch)
             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
         """
-        attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
+        attention_output, attention_probs = self.attention(hidden_states, cu_seqlens, seqlen,
                                           subset_idx, indices, attn_mask, bias)
         print(f'BertLayer attention_output shape: {attention_output.shape}')
         layer_output = self.mlp(attention_output)
         print(f'BertLayer layer_output shape: {layer_output.shape}')
-        return layer_output, attention_output  # JAANDOUI: this only returns layer_output in the original work.
+        return layer_output, attention_output, attention_probs  # JAANDOUI: this only returns layer_output in the original work.


 class BertEncoder(nn.Module):
@@ -467,11 +467,12 @@ class BertEncoder(nn.Module):

         all_encoder_layers = []
         all_attention_weights = []  # List to store attention weights
+        all_attention_probs = []

         if subset_mask is None:
             for layer_module in self.layer:
                 # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
-                hidden_states, attention_weights = layer_module(hidden_states,
+                hidden_states, attention_weights, attention_probs = layer_module(hidden_states,
                                                                 cu_seqlens,
                                                                 seqlen,
                                                                 None,
@@ -482,6 +483,8 @@ class BertEncoder(nn.Module):
                 # print(f'Inner Attention: {attention_weights}')
                 print(f'Inner Attention shape: {attention_weights.shape}')
                 all_attention_weights.append(attention_weights)  # Store attention weights
+                all_attention_probs.append(attention_probs)  # Store attention probs
+
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             # Pad inputs and mask. It will insert back zero-padded tokens.
@@ -494,7 +497,7 @@ class BertEncoder(nn.Module):
             for i in range(len(self.layer) - 1):
                 layer_module = self.layer[i]
                 # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
-                hidden_states, attention_weights = layer_module(hidden_states,
+                hidden_states, attention_weights, attention_probs = layer_module(hidden_states,
                                                                 cu_seqlens,
                                                                 seqlen,
                                                                 None,
@@ -502,12 +505,14 @@ class BertEncoder(nn.Module):
                                                                 attn_mask=attention_mask,
                                                                 bias=alibi_attn_mask)
                 all_attention_weights.append(attention_weights)  # JAANDOUI: Store attention weights
+                all_attention_probs.append(attention_probs)  # Store attention probs
+
                 if output_all_encoded_layers:
                     all_encoder_layers.append(hidden_states)
             subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                        as_tuple=False).flatten()
             # JAANDOUI: Since we get now attention too, we need to unpack 2 elements instead of 1.
-            hidden_states, attention_weights = self.layer[-1](hidden_states,
+            hidden_states, attention_weights, attention_probs = self.layer[-1](hidden_states,
                                                               cu_seqlens,
                                                               seqlen,
                                                               subset_idx=subset_idx,
@@ -515,6 +520,8 @@ class BertEncoder(nn.Module):
                                                               attn_mask=attention_mask,
                                                               bias=alibi_attn_mask)
             all_attention_weights.append(attention_weights)  # JAANDOUI: appending the attention of different layers together.
+            all_attention_probs.append(attention_probs)  # Store attention probs
+
             # print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
             print(f'and this is the [0]shape inside encoder: \n {all_attention_weights[0].shape}')
             # print(f'NUMBER6: {all_attention_weights}')
@@ -522,7 +529,7 @@ class BertEncoder(nn.Module):
             all_encoder_layers.append(hidden_states)

         # JAANDOUI: Since we now return both, we need to handle them wherever BertEncoder forward is called.
-        return all_encoder_layers, all_attention_weights  # Return both hidden states and attention weights
+        return all_encoder_layers, all_attention_weights, all_attention_probs  # Return both hidden states and attention weights
         # return all_encoder_layers  # JAANDOUI: original return.


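A hedged usage sketch of the encoder's new three-value return (the encoder instance, embedding_output, and attention_mask are assumed to exist; subset_mask and the ALiBi bias handling are omitted here for brevity):

from typing import List, Tuple
import torch

def run_encoder(encoder: torch.nn.Module,
                embedding_output: torch.Tensor,
                attention_mask: torch.Tensor
                ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
    # BertEncoder.forward now hands back three values instead of one list.
    all_encoder_layers, all_attention_weights, all_attention_probs = encoder(
        embedding_output,
        attention_mask,
        output_all_encoded_layers=False)
    sequence_output = all_encoder_layers[-1]  # final hidden states
    return sequence_output, all_attention_weights, all_attention_probs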
@@ -649,7 +656,7 @@ class BertModel(BertPreTrainedModel):

         # JAANDOUI: first part where we call self.encoder (which is the instance of BertEncoder defined here)
         # JAANDOUI: need to return the attention weights here too.
-        encoder_outputs, all_attention_weights = self.encoder(
+        encoder_outputs, all_attention_weights, all_attention_probs = self.encoder(
             embedding_output,
             attention_mask,
             output_all_encoded_layers=output_all_encoded_layers,
@@ -681,11 +688,11 @@ class BertModel(BertPreTrainedModel):
         # JAANDOUI: returning all_attention_weights too
         if self.pooler is not None:
             # print(f'NUMBER8: {all_attention_weights}')
-            return encoder_outputs, pooled_output, all_attention_weights
+            return encoder_outputs, pooled_output, all_attention_weights, all_attention_probs

         # JAANDOUI: returning all_attention_weights too
         # print(f'NUMBER9: {all_attention_weights}')
-        return encoder_outputs, None, all_attention_weights
+        return encoder_outputs, None, all_attention_weights, all_attention_probs
         # JAANDOUI: need to handle the returned elements wherever BertModel is instantiated.

 ###################
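The classification head further down indexes these returns positionally (outputs[1], outputs[3]), so the assumed layout is worth spelling out; a small hedged helper, not part of the commit:

from typing import Any, Dict, Tuple

def split_bert_model_outputs(outputs: Tuple[Any, ...]) -> Dict[str, Any]:
    # Positional layout implied by the two return statements above.
    return {
        'encoder_outputs': outputs[0],
        'pooled_output': outputs[1],          # None when the pooler is disabled
        'all_attention_weights': outputs[2],  # per-layer list
        'all_attention_probs': outputs[3],    # per-layer list, new in this commit
    }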
@@ -924,27 +931,27 @@ class BertForSequenceClassification(BertPreTrainedModel):
         pooled_output = outputs[1]

         try:
-            print(f'outputs[2] before reassignment SHAPE: {outputs[
+            print(f'outputs[2] before reassignment SHAPE: {outputs[3][0].shape} ')
         except:
-            print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[
+            print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[3][0])} '))

         # JAANDOUI:
-
+        all_attention_probs = outputs[3]

         try:
-            print(f'outputs[2] AFTER reassignment SHAPE: {outputs[
+            print(f'outputs[2] AFTER reassignment SHAPE: {outputs[3][0].shape} ')
         except:
-            print(print(f'outputs[2] AFTER reassignment LENGTH: {len(outputs[
+            print(print(f'outputs[2] AFTER reassignment LENGTH: {len(outputs[3][0])} '))



         try:
-            print(f'all_attention_weights last: {
+            print(f'all_attention_weights last: {all_attention_probs.shape}')
         except:
             try:
-                print(f'last first except: {
+                print(f'last first except: {all_attention_probs[0].shape}')
             except:
-                print(f'last second except: {len(
+                print(f'last second except: {len(all_attention_probs[0])}')


         pooled_output = self.dropout(pooled_output)
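The nested try/except print blocks above only distinguish "has a .shape" from "has a len()"; an illustrative (not committed) helper that reports the same information without swallowing unrelated exceptions:

from typing import Any

def describe(name: str, value: Any) -> str:
    # Report a tensor's shape, a container's length, or fall back to the type.
    if hasattr(value, 'shape'):
        return f'{name}: tensor with shape {tuple(value.shape)}'
    if hasattr(value, '__len__'):
        return f'{name}: container of length {len(value)}'
    return f'{name}: {type(value).__name__}'

# Example: print(describe('all_attention_probs', outputs[3]))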
@@ -984,9 +991,9 @@ class BertForSequenceClassification(BertPreTrainedModel):

         # print(outputs.attentions)
         try:
-            print(f'not stacked final attention SHAPE: {outputs[
+            print(f'not stacked final attention SHAPE: {outputs[3][0].shape}')
         except:
-            print(f'not stacked final attention LEN: {len(outputs[
+            print(f'not stacked final attention LEN: {len(outputs[3])}')

         # try:
         #     print(f'STACKED final attention SHAPE: {(outputs.attentions).shape}')
@@ -1002,6 +1009,6 @@ class BertForSequenceClassification(BertPreTrainedModel):
             hidden_states=outputs[0],
             #JAANDOUI: returning all_attention_weights here
             # attentions=torch.stack(outputs[2], dim=0),
-            attentions=outputs[
+            attentions=outputs[3],  # JAANDOUI TODO: should I stack here ????
         )

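On the "TODO: should I stack here ????": Hugging Face's SequenceClassifierOutput.attentions is conventionally a tuple with one (batch_size, num_heads, seq_len, seq_len) tensor per layer, so passing the per-layer list through unstacked already matches that convention. Stacking into a single (num_layers, ...) tensor is optional and only well defined when every layer's tensor has the same shape; a hedged helper sketch:

from typing import List, Optional
import torch

def maybe_stack_attentions(per_layer: List[torch.Tensor]) -> Optional[torch.Tensor]:
    # Stack per-layer attention tensors into (num_layers, ...) when shapes agree;
    # otherwise keep them as a list/tuple (e.g. unpadded nnz layouts can differ).
    if not per_layer:
        return None
    first_shape = per_layer[0].shape
    if any(t.shape != first_shape for t in per_layer):
        return None
    return torch.stack(per_layer, dim=0)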