Delete padding mask in MHA
Browse filesPadding already in att mask
- modelling_hier.py +0 -2
modelling_hier.py
CHANGED
@@ -339,7 +339,6 @@ class HierBert(Module):
|
|
339 |
# Shared encoders or Segment-wise encoders
|
340 |
# print("SWE")
|
341 |
enc_inp, att_w = layer(enc_inp,
|
342 |
-
src_key_padding_mask=src_key_padding_mask,
|
343 |
src_mask=enc_mask_utt.repeat(self.config.num_attention_heads, 1, 1))
|
344 |
else:
|
345 |
# Positional Embedding for Context Encoder if few connected CSE use it before
|
@@ -347,7 +346,6 @@ class HierBert(Module):
|
|
347 |
# Context encoder or Cross-segment encoders
|
348 |
# print("CSE")
|
349 |
enc_inp, att_w = layer(enc_inp,
|
350 |
-
src_key_padding_mask=src_key_padding_mask,
|
351 |
src_mask=enc_mask_ct.repeat(self.config.num_attention_heads, 1, 1))
|
352 |
if output_attentions:
|
353 |
all_self_attentions = all_self_attentions + (att_w,)
|
|
|
339 |
# Shared encoders or Segment-wise encoders
|
340 |
# print("SWE")
|
341 |
enc_inp, att_w = layer(enc_inp,
|
|
|
342 |
src_mask=enc_mask_utt.repeat(self.config.num_attention_heads, 1, 1))
|
343 |
else:
|
344 |
# Positional Embedding for Context Encoder if few connected CSE use it before
|
|
|
346 |
# Context encoder or Cross-segment encoders
|
347 |
# print("CSE")
|
348 |
enc_inp, att_w = layer(enc_inp,
|
|
|
349 |
src_mask=enc_mask_ct.repeat(self.config.num_attention_heads, 1, 1))
|
350 |
if output_attentions:
|
351 |
all_self_attentions = all_self_attentions + (att_w,)
|