igorktech committed
Commit 77b797c · 1 Parent(s): eb37430

Delete padding mask in MHA


Padding is already encoded in the attention mask, so the separate key padding mask is redundant.

Files changed (1):
  modelling_hier.py  +0 -2
modelling_hier.py CHANGED
@@ -339,7 +339,6 @@ class HierBert(Module):
                 # Shared encoders or Segment-wise encoders
                 # print("SWE")
                 enc_inp, att_w = layer(enc_inp,
-                                       src_key_padding_mask=src_key_padding_mask,
                                        src_mask=enc_mask_utt.repeat(self.config.num_attention_heads, 1, 1))
             else:
                 # Positional Embedding for Context Encoder if few connected CSE use it before
@@ -347,7 +346,6 @@ class HierBert(Module):
                 # Context encoder or Cross-segment encoders
                 # print("CSE")
                 enc_inp, att_w = layer(enc_inp,
-                                       src_key_padding_mask=src_key_padding_mask,
                                        src_mask=enc_mask_ct.repeat(self.config.num_attention_heads, 1, 1))
             if output_attentions:
                 all_self_attentions = all_self_attentions + (att_w,)
 
339
  # Shared encoders or Segment-wise encoders
340
  # print("SWE")
341
  enc_inp, att_w = layer(enc_inp,
 
342
  src_mask=enc_mask_utt.repeat(self.config.num_attention_heads, 1, 1))
343
  else:
344
  # Positional Embedding for Context Encoder if few connected CSE use it before
 
346
  # Context encoder or Cross-segment encoders
347
  # print("CSE")
348
  enc_inp, att_w = layer(enc_inp,
 
349
  src_mask=enc_mask_ct.repeat(self.config.num_attention_heads, 1, 1))
350
  if output_attentions:
351
  all_self_attentions = all_self_attentions + (att_w,)
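
For reference, a minimal sketch (not taken from this repo) of why the separate padding mask is redundant once padding is already folded into the per-head attention mask passed as src_mask. The toy shapes, example sequence lengths, and the direct use of torch.nn.MultiheadAttention here are assumptions for illustration only.

import torch

# Assumed toy dimensions: batch N, sequence length S, attention heads H, embed dim E.
N, S, H, E = 2, 5, 4, 8
lengths = torch.tensor([5, 3])  # the second sequence has 2 padded positions

# Boolean padding mask: True marks key positions that must not be attended to.
key_padding_mask = torch.arange(S)[None, :] >= lengths[:, None]        # (N, S)

# Fold the padding into a per-head attention mask, analogous to how
# enc_mask_utt / enc_mask_ct are built before the encoder layer is called.
att_mask = key_padding_mask[:, None, :].expand(N, S, S)                # (N, S, S)
att_mask = att_mask.repeat_interleave(H, dim=0)                        # (N*H, S, S)

mha = torch.nn.MultiheadAttention(embed_dim=E, num_heads=H, batch_first=True)
x = torch.randn(N, S, E)

# Because att_mask already blocks every padded key column, additionally
# passing key_padding_mask does not change the attention output.
out_mask_only, _ = mha(x, x, x, attn_mask=att_mask)
out_both, _ = mha(x, x, x, attn_mask=att_mask, key_padding_mask=key_padding_mask)
print(torch.allclose(out_mask_only, out_both))  # True

Dropping src_key_padding_mask from the layer calls therefore leaves the attention computation unchanged, since the masks built from enc_mask_utt and enc_mask_ct already carry the padding information.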