Update bert_layers.py
bert_layers.py
@@ -169,12 +169,12 @@ class BertUnpadSelfAttention(nn.Module):
                 self.attention_head_size)
             attention_scores = attention_scores + bias
             attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-            print(f'BUSA: attention_probs 1 shape: {attention_probs.shape}')
+            # print(f'BUSA: attention_probs 1 shape: {attention_probs.shape}')
             attention_probs = self.dropout(attention_probs)
-            print(f'BUSA: attention_probs 2 shape: {attention_probs.shape}')
+            # print(f'BUSA: attention_probs 2 shape: {attention_probs.shape}')
             attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                  3)  # b s h d
-            print(f'BUSA: attention shape: {attention.shape}')
+            # print(f'BUSA: attention shape: {attention.shape}')
         else:
             # Triton implementation only supports 0 attention dropout
             convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
@@ -185,24 +185,24 @@ class BertUnpadSelfAttention(nn.Module):
                 bias_dtype = bias.dtype
                 bias = bias.to(torch.float16)
                 attention = flash_attn_qkvpacked_func(qkv, bias)
-                print(f'BUSA Triton: attention 0 shape: {attention_probs.shape}')
+                # print(f'BUSA Triton: attention 0 shape: {attention_probs.shape}')
                 attention = attention.to(orig_dtype)
-                print(f'BUSA Triton: attention 1 shape: {attention_probs.shape}')
+                # print(f'BUSA Triton: attention 1 shape: {attention_probs.shape}')
                 bias = bias.to(bias_dtype)
             else:
                 attention = flash_attn_qkvpacked_func(qkv, bias)
-                print(f'BUSA Triton: attention 2 shape: {attention_probs.shape}')
+                # print(f'BUSA Triton: attention 2 shape: {attention_probs.shape}')
         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
-        print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
-        print(f'ATTENTION: {attention.shape}')
+        # print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
+        # print(f'ATTENTION: {attention.shape}')

-        print(f'PROBLEM HERE: UNDERSTAND IT!!')
+        # print(f'PROBLEM HERE: UNDERSTAND IT!!')
         rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
-        try:
-            print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
-        except:
-            print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
+        # try:
+        #     print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
+        # except:
+        #     print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
         return rearrange(attention, 'nnz h d -> nnz (h d)'), attention_probs


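Note (not part of the diff): the debug prints commented out above were inspecting the head-flattening done by rearrange before the return. A minimal, self-contained sketch of that reshape with made-up sizes (nnz = number of unpadded tokens, h = heads, d = head dimension):

import torch
from einops import rearrange

nnz, h, d = 10, 12, 64                       # hypothetical sizes
attention = torch.randn(nnz, h, d)           # per-token, per-head attention output
flat = rearrange(attention, 'nnz h d -> nnz (h d)')
print(flat.shape)                            # torch.Size([10, 768]) == (nnz, h * d)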
@@ -257,10 +257,10 @@ class BertUnpadAttention(nn.Module):
         self_output, attention_probs = self.self(input_tensor, cu_seqlens, max_s, indices,
                                                  attn_mask, bias)

-        try:
-            print(f'IMPORTANT: {self_output.shape}')
-        except:
-            print(f'IMPORTANT2: {self_output[0].shape}')
+        # try:
+        #     print(f'IMPORTANT: {self_output.shape}')
+        # except:
+        #     print(f'IMPORTANT2: {self_output[0].shape}')

         if subset_idx is not None:
             return self.output(index_first_axis(self_output, subset_idx),
@@ -349,9 +349,9 @@ class BertLayer(nn.Module):
         """
         attention_output, attention_probs = self.attention(hidden_states, cu_seqlens, seqlen,
                                                            subset_idx, indices, attn_mask, bias)
-        print(f'BertLayer attention_output shape: {attention_output.shape}')
+        # print(f'BertLayer attention_output shape: {attention_output.shape}')
         layer_output = self.mlp(attention_output)
-        print(f'BertLayer layer_output shape: {layer_output.shape}')
+        # print(f'BertLayer layer_output shape: {layer_output.shape}')
         return layer_output, attention_output, attention_probs  # JAANDOUI: this only returns layer_output in the original work.


@@ -372,7 +372,7 @@ class BertEncoder(nn.Module):
             [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

         self.num_attention_heads = config.num_attention_heads
-        print(f'nbr of attention heads: {self.num_attention_heads}')
+        # print(f'nbr of attention heads: {self.num_attention_heads}')
         # The alibi mask will be dynamically expanded if it is too small for
         # the input the model receives. But it generally helps to initialize it
         # to a reasonably large size to help pre-allocate CUDA memory.
@@ -481,7 +481,7 @@ class BertEncoder(nn.Module):
                                              bias=alibi_attn_mask)
                 # JAANDOUI
                 # print(f'Inner Attention: {attention_weights}')
-                print(f'Inner Attention shape: {attention_weights.shape}')
+                # print(f'Inner Attention shape: {attention_weights.shape}')
                 all_attention_weights.append(attention_weights)  # Store attention weights
                 all_attention_probs.append(attention_probs)  # Store attention probs

@@ -523,7 +523,7 @@ class BertEncoder(nn.Module):
                 all_attention_probs.append(attention_probs)  # Store attention probs

         # print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
-        print(f'and this is the [0]shape inside encoder: \n {all_attention_weights[0].shape}')
+        # print(f'and this is the [0]shape inside encoder: \n {all_attention_weights[0].shape}')
         # print(f'NUMBER6: {all_attention_weights}')
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
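Note (not part of the diff): the encoder hunks above collect one attention tensor per layer in all_attention_weights / all_attention_probs. A minimal sketch of stacking such a list for downstream inspection, assuming every layer yields a tensor of the same shape (sizes below are invented):

import torch

num_layers, b, h, s = 12, 2, 12, 128         # hypothetical sizes
all_attention_probs = [torch.softmax(torch.randn(b, h, s, s), dim=-1)
                       for _ in range(num_layers)]
stacked = torch.stack(all_attention_probs)   # shape: (num_layers, b, h, s, s)
print(stacked.shape)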
@@ -663,7 +663,7 @@ class BertModel(BertPreTrainedModel):
             subset_mask=subset_mask)
         # print(f'NUMBER7: {all_attention_weights}')
         # print(f'here is the matrix of attentions in BERT: \n {all_attention_weights}')
-        print(f'and this is the [0]shape in BERT: \n {all_attention_weights[0].shape}')
+        # print(f'and this is the [0]shape in BERT: \n {all_attention_weights[0].shape}')

         if masked_tokens_mask is None:
             sequence_output = encoder_outputs[-1]
@@ -930,28 +930,28 @@ class BertForSequenceClassification(BertPreTrainedModel):

         pooled_output = outputs[1]

-        try:
-            print(f'outputs[2] before reassignment SHAPE: {outputs[3][0].shape} ')
-        except:
-            print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[3][0])} '))
+        # try:
+        #     print(f'outputs[2] before reassignment SHAPE: {outputs[3][0].shape} ')
+        # except:
+        #     print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[3][0])} '))

         # JAANDOUI:
         all_attention_probs = outputs[3]

-        try:
-            print(f'outputs[2] AFTER reassignment probsss SHAPE: {outputs[3][0].shape} ')
-        except:
-            print(print(f'outputs[2] AFTER reassignment probsss LENGTH: {len(outputs[3][0])} '))
+        # try:
+        #     print(f'outputs[2] AFTER reassignment probsss SHAPE: {outputs[3][0].shape} ')
+        # except:
+        #     print(print(f'outputs[2] AFTER reassignment probsss LENGTH: {len(outputs[3][0])} '))



-        try:
-            print(f'all_attention_weights probsss last: {all_attention_probs.shape}')
-        except:
-            try:
-                print(f'last first except probsss: {all_attention_probs[0].shape}')
-            except:
-                print(f'last second except probsss: {len(all_attention_probs[0])}')
+        # try:
+        #     print(f'all_attention_weights probsss last: {all_attention_probs.shape}')
+        # except:
+        #     try:
+        #         print(f'last first except probsss: {all_attention_probs[0].shape}')
+        #     except:
+        #         print(f'last second except probsss: {len(all_attention_probs[0])}')


         pooled_output = self.dropout(pooled_output)
@@ -990,10 +990,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
             return ((loss,) + output) if loss is not None else output

         # print(outputs.attentions)
-        try:
-            print(f'not stacked final attention probsss SHAPE: {outputs[3][0].shape}')
-        except:
-            print(f'not stacked final attention probsss LEN: {len(outputs[3])}')
+        # try:
+        #     print(f'not stacked final attention probsss SHAPE: {outputs[3][0].shape}')
+        # except:
+        #     print(f'not stacked final attention probsss LEN: {len(outputs[3])}')

         # try:
         #     print(f'STACKED final attention SHAPE: {(outputs.attentions).shape}')
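Usage note (not part of the diff): per the hunk at line 939, the classifier reads the collected attention probabilities from outputs[3]. A hedged sketch of how a caller might inspect them; the tuple index comes from this diff, while the helper, its name, and the stand-in shapes are hypothetical:

import torch

def summarize_attention_probs(outputs):
    # Index 3 is where this diff stores the per-layer attention probabilities.
    all_attention_probs = outputs[3]
    for i, probs in enumerate(all_attention_probs):
        if torch.is_tensor(probs):
            print(f'layer {i}: shape {tuple(probs.shape)}')
        else:
            print(f'layer {i}: non-tensor entry of length {len(probs)}')

# Stand-in for a real forward pass (shapes invented for illustration):
fake_outputs = (None, None, None,
                [torch.rand(2, 12, 128, 128) for _ in range(12)])
summarize_attention_probs(fake_outputs)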