robinzixuan committed
Update modeling_bert.py

modeling_bert.py  +3 -151
CHANGED
@@ -6,6 +6,7 @@
 # you may not use this file except in compliance with the License.
 
 # You may obtain a copy of the License at
+
 #
 # http://www.apache.org/licenses/LICENSE-2.0
 #
@@ -411,156 +412,7 @@ class BertSelfAttention(nn.Module):
         return outputs
 
 
-class BertOutEffHop(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
-        super().__init__()
-        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
-            raise ValueError(
-                f'''The hidden size ({
-                    config.hidden_size}) is not a multiple of the number of attention '''
-                f"heads ({config.num_attention_heads})"
-            )
-
-        self.num_attention_heads = config.num_attention_heads
-        self.attention_head_size = int(
-            config.hidden_size / config.num_attention_heads)
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = position_embedding_type or getattr(
-            config, "position_embedding_type", "absolute"
-        )
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            self.max_position_embeddings = config.max_position_embeddings
-            self.distance_embedding = nn.Embedding(
-                2 * config.max_position_embeddings - 1, self.attention_head_size)
-
-        self.is_decoder = config.is_decoder
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[
-            :-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
-
-        # If this is instantiated as a cross-attention module, the keys
-        # and values come from an encoder; the attention mask needs to be
-        # such that the encoder's padding tokens are not attended to.
-        is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(
-                self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(
-                self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(
-            query_layer, key_layer.transpose(-1, -2))
-
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
-            if use_cache:
-                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-                    -1, 1
-                )
-            else:
-                position_ids_l = torch.arange(
-                    query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(
-                key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
-            distance = position_ids_l - position_ids_r
-
-            positional_embedding = self.distance_embedding(
-                distance + self.max_position_embeddings - 1)
-            positional_embedding = positional_embedding.to(
-                dtype=query_layer.dtype)  # fp16 compatibility
-
-            if self.position_embedding_type == "relative_key":
-                relative_position_scores = torch.einsum(
-                    "bhld,lrd->bhlr", query_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores
-            elif self.position_embedding_type == "relative_key_query":
-                relative_position_scores_query = torch.einsum(
-                    "bhld,lrd->bhlr", query_layer, positional_embedding)
-                relative_position_scores_key = torch.einsum(
-                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
-                attention_scores = attention_scores + \
-                    relative_position_scores_query + relative_position_scores_key
-
-        attention_scores = attention_scores / \
-            math.sqrt(self.attention_head_size)
-        if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
-        # Normalize the attention scores to probabilities.
-        attention_probs = softmax_1(attention_scores, dim=-1)
-        print(softmax_1)
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        context_layer = torch.matmul(attention_probs, value_layer)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[
-            :-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-
-        outputs = (context_layer, attention_probs) if output_attentions else (
-            context_layer,)
-
-        if self.is_decoder:
-            outputs = outputs + (past_key_value,)
-        return outputs
 
 
 class BertSdpaSelfAttention(BertSelfAttention):
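The removed BertOutEffHop module closely mirrors the eager BertSelfAttention implementation; the functional differences are that it normalizes the attention scores with softmax_1 instead of the standard softmax, and that it leaves a debug print(softmax_1) in the forward pass. softmax_1 itself is not defined in this diff; assuming it is the usual "softmax plus one" variant used in outlier-efficient attention work (an extra 1 in the denominator, so a row can assign near-zero total attention), a minimal sketch would be:

import torch

def softmax_1(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Sketch only: softmax with an extra 1 in the denominator,
    softmax_1(x)_i = exp(x_i) / (1 + sum_j exp(x_j)).
    The softmax_1 actually used by this repo is imported elsewhere and may differ."""
    # Shift by max(x, 0) for numerical stability; exp(-m) is the "+1" term
    # after the shift.
    m = x.max(dim=dim, keepdim=True).values.clamp(min=0)
    e = torch.exp(x - m)
    return e / (torch.exp(-m) + e.sum(dim=dim, keepdim=True))

With the class dropped and BertAttention pointed at BertSelfAttention (next hunk), the attention computation follows the standard eager softmax path again.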
@@ -684,14 +536,14 @@ class BertSelfOutput(nn.Module):
 BERT_SELF_ATTENTION_CLASSES = {
     "eager": BertSelfAttention,
     "sdpa": BertSdpaSelfAttention,
-
+
 }
 
 
 class BertAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self =
+        self.self = BertSelfAttention(
             config, position_embedding_type=position_embedding_type
         )
         self.output = BertSelfOutput(config)