|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
|
|
import logging
|
|
import math
|
|
import warnings
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from einops import rearrange
|
|
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import (MaskedLMOutput,
|
|
SequenceClassifierOutput)
|
|
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
from .bert_padding import (index_first_axis,
|
|
index_put_first_axis, pad_input,
|
|
unpad_input, unpad_input_only)
|
|
|
|
try:
    from .flash_attn_triton import flash_attn_qkvpacked_func
except ImportError:
    # The Triton flash-attention kernel is optional. When it cannot be
    # imported, attention layers check for ``None`` and fall back to a
    # pure-PyTorch implementation.
    flash_attn_qkvpacked_func = None

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BertEmbeddings(nn.Module):
    """Build input representations from word and token-type embeddings.

    No position-embedding table is created: ``position_ids`` is accepted for
    interface compatibility but ignored (positional information is supplied
    by the ALiBi attention bias in the encoder). The sum of the embeddings is
    passed through LayerNorm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 1-D buffer of zeros used as default token types. It is
        # non-persistent: it follows the module across devices but is not
        # stored in the state dict.
        self.register_buffer('token_type_ids',
                             torch.zeros(config.max_position_embeddings,
                                         dtype=torch.long),
                             persistent=False)

    def _default_token_type_ids(self, batch_size: int,
                                seq_length: int) -> torch.Tensor:
        """Return an all-zeros (batch_size, seq_length) token-type tensor."""
        buffered = getattr(self, 'token_type_ids', None)
        if buffered is not None and buffered.size(0) >= seq_length:
            # The buffer is 1-D, so slice it and add a batch axis before
            # expanding (expand creates a view, not a copy).
            return buffered[:seq_length].unsqueeze(0).expand(
                batch_size, seq_length)
        # The buffer is missing or shorter than the sequence (e.g. inputs
        # longer than max_position_embeddings): allocate zeros directly.
        return torch.zeros((batch_size, seq_length),
                           dtype=torch.long,
                           device=self.word_embeddings.weight.device)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed tokens.

        Args:
            input_ids: token ids of shape (batch, seq). Mutually exclusive
                with ``inputs_embeds``.
            token_type_ids: optional (batch, seq) segment ids; defaults to zeros.
            position_ids: ignored (see class docstring).
            inputs_embeds: precomputed word embeddings (batch, seq, hidden).
            past_key_values_length: unused; kept for interface compatibility.

        Returns:
            Tensor of shape (batch, seq, hidden).

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError('Must specify either input_ids or input_embeds!')
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        batch_size, seq_length = input_shape[0], input_shape[1]

        # position_ids is intentionally unused: there is no position table.

        if token_type_ids is None:
            token_type_ids = self._default_token_type_ids(batch_size,
                                                          seq_length)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
|
|
class BertUnpadSelfAttention(nn.Module):
    """Multi-head self-attention for Mosaic BERT over unpadded tokens.

    Token states arrive concatenated across the batch (no padding). They are
    projected to q/k/v, re-padded for the attention computation, and unpadded
    again on the way out. The Triton flash-attention kernel is used when it
    was importable and attention dropout is zero. Otherwise an explicit
    PyTorch implementation is used.
    """

    def __init__(self, config):
        super().__init__()
        n_heads = config.num_attention_heads
        divisible = config.hidden_size % n_heads == 0
        if not divisible and not hasattr(config, 'embedding_size'):
            raise ValueError(
                f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
                f'heads ({config.num_attention_heads})')

        self.num_attention_heads = n_heads
        self.attention_head_size = int(config.hidden_size / n_heads)
        self.all_head_size = n_heads * self.attention_head_size
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.p_dropout = config.attention_probs_dropout_prob
        # One fused projection producing queries, keys and values together.
        self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)

        if flash_attn_qkvpacked_func is None:
            warnings.warn(
                'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
            )

    def _torch_attention(self, packed: torch.Tensor,
                         bias: torch.Tensor) -> torch.Tensor:
        """Plain PyTorch attention on padded qkv of shape (b, s, 3, h, d).

        Returns the attention context of shape (b, s, h, d).
        """
        query = packed[:, :, 0].permute(0, 2, 1, 3)  # (b, h, s, d)
        key_t = packed[:, :, 1].permute(0, 2, 3, 1)  # (b, h, d, s)
        value = packed[:, :, 2].permute(0, 2, 1, 3)  # (b, h, s, d)
        scores = torch.matmul(query, key_t) / math.sqrt(
            self.attention_head_size)
        probs = nn.functional.softmax(scores + bias, dim=-1)
        probs = self.dropout(probs)
        # (b, h, s, d) -> (b, s, h, d)
        return torch.matmul(probs, value).permute(0, 2, 1, 3)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                max_seqlen_in_batch: int, indices: torch.Tensor,
                attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Run self-attention on unpadded token states.

        The inputs are unpadded, but both attention implementations work on
        padded tensors. The fused qkv projection is therefore padded with
        ``pad_input``, attended, and unpadded again with ``unpad_input_only``.
        This round-trip costs some overhead but keeps pad tokens out of the
        feed-forward layers.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen_in_batch: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen_in_batch)
            bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)

        Returns:
            attention: (total_nnz, dim)
        """
        n_sequences = cu_seqlens.shape[0] - 1
        packed = self.Wqkv(hidden_states)
        packed = pad_input(packed, indices, n_sequences, max_seqlen_in_batch)
        # (b, s, 3 * h * d) -> (b, s, 3, h, d)
        packed = rearrange(packed,
                           'b s (t h d) -> b s t h d',
                           t=3,
                           h=self.num_attention_heads)

        use_flash = flash_attn_qkvpacked_func is not None and not self.p_dropout
        if not use_flash:
            context = self._torch_attention(packed, bias)
        elif packed.dtype in (torch.float16, torch.bfloat16):
            context = flash_attn_qkvpacked_func(packed, bias)
        else:
            # Only fp16/bf16 tensors are handed to the Triton kernel; other
            # dtypes are round-tripped through fp16.
            input_dtype = packed.dtype
            context = flash_attn_qkvpacked_func(packed.to(torch.float16),
                                                bias.to(torch.float16))
            context = context.to(input_dtype)

        keep = torch.squeeze(attn_mask) == 1
        context = unpad_input_only(context, keep)
        return rearrange(context, 'nnz h d -> nnz (h d)')
|
|
|
|
|
|
|
|
class BertSelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
class BertUnpadAttention(nn.Module):
    """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""

    def __init__(self, config):
        super().__init__()
        self.self = BertUnpadSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        input_tensor: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attention over unpadded tokens followed by the output sublayer.

        Arguments:
            input_tensor: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_s: int
            subset_idx: optional indices of the rows to keep (for example the
                masked tokens in the final layer). When given, only those rows
                are projected and returned.
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        attended = self.self(input_tensor, cu_seqlens, max_s, indices,
                             attn_mask, bias)
        residual = input_tensor
        if subset_idx is not None:
            # Restrict both the attention result and its residual to the
            # requested rows before the dense/LayerNorm sublayer.
            attended = index_first_axis(attended, subset_idx)
            residual = index_first_axis(residual, subset_idx)
        return self.output(attended, residual)
|
|
|
|
|
|
class BertGatedLinearUnitMLP(nn.Module):
    """Applies the FFN at the end of each Mosaic BERT layer.

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single
    module that has similar functionality, but introduces Gated Linear Units.

    Note: Mosaic BERT adds parameters in order to implement Gated Linear
    Units. To keep parameter count consistent with that of a standard Hugging
    Face BERT, scale down `config.intermediate_size` by 2/3. For example, a
    Mosaic BERT constructed with `config.intermediate_size=2048` will have the
    same parameter footprint as its Hugging Face BERT counterpart constructed
    with the `config.intermediate_size=3072`. However, in most cases it will
    not be necessary to adjust `config.intermediate_size` since, despite the
    increased parameter size, Mosaic BERT typically offers a net higher
    throughput than a Hugging Face BERT built from the same `config`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One projection yields both the gate and the value halves.
        self.gated_layers = nn.Linear(config.hidden_size,
                                      config.intermediate_size * 2,
                                      bias=False)
        self.act = nn.GELU(approximate='none')
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer, typically [nnz, dim]. Any number of
                leading dimensions is accepted; features are on the last axis.

        Returns:
            Tensor with the same shape as ``hidden_states``:
            ``LayerNorm(wo(dropout(GELU(gate) * value)) + hidden_states)``.
        """
        residual_connection = hidden_states

        projected = self.gated_layers(hidden_states)
        # Split on the feature (last) axis so that inputs with any leading
        # shape work, not just 2-D [nnz, dim] tensors.
        gated, non_gated = projected.split(self.config.intermediate_size,
                                           dim=-1)
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.wo(hidden_states)

        hidden_states = self.layernorm(hidden_states + residual_connection)
        return hidden_states
|
|
|
|
|
|
class BertLayer(nn.Module):
    """Composes the Mosaic BERT attention and FFN blocks into a single layer."""

    def __init__(self, config):
        super().__init__()
        self.attention = BertUnpadAttention(config)
        self.mlp = BertGatedLinearUnitMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        seqlen: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention and then the gated MLP to unpadded token states.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            seqlen: int
            subset_idx: optional indices of the rows to keep after attention
                (for example the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        attended = self.attention(hidden_states, cu_seqlens, seqlen,
                                  subset_idx, indices, attn_mask, bias)
        return self.mlp(attended)
|
|
|
|
|
|
class BertEncoder(nn.Module):
    """A stack of BERT layers providing the backbone of Mosaic BERT.

    This module is modeled after the Hugging Face BERT's
    :class:`~transformers.models.bert.modeling_bert.BertEncoder`, but with
    substantial modifications to implement unpadding and ALiBi.

    Compared to the analogous Hugging Face BERT module, this module handles
    unpadding to reduce unnecessary computation at padded tokens, and
    pre-computes attention biases to implement ALiBi.
    """

    def __init__(self, config):
        super().__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

        self.num_attention_heads = config.num_attention_heads

        # The ALiBi tensor is a plain attribute, not a registered buffer. It
        # is therefore excluded from the state dict and not moved by
        # ``module.to(...)``; ``forward`` relocates or grows it on demand.
        self._current_alibi_size = int(config.alibi_starting_size)
        self.alibi = torch.zeros(
            (1, self.num_attention_heads, self._current_alibi_size,
             self._current_alibi_size))
        self.rebuild_alibi_tensor(size=config.alibi_starting_size)

    def rebuild_alibi_tensor(self,
                             size: int,
                             device: Optional[Union[torch.device, str]] = None):
        """(Re)build the ALiBi bias tensor for sequences of up to ``size`` tokens.

        Head ``h`` receives the bias ``-slope_h * |i - j|`` between positions
        ``i`` and ``j``. For a power-of-two head count the slopes form the
        geometric sequence starting at ``2 ** (-8 / n_heads)``. Otherwise the
        slopes of the closest lower power of two are extended with every
        other slope of the next power of two.

        Stores a tensor of shape (1, n_heads, size, size) in ``self.alibi``
        and updates ``self._current_alibi_size``.
        """
        n_heads = self.num_attention_heads

        def _get_alibi_head_slopes(n_heads: int) -> List[float]:

            def get_slopes_power_of_2(n_heads: int) -> List[float]:
                start = (2**(-2**-(math.log2(n_heads) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n_heads)]

            if math.log2(n_heads).is_integer():
                return get_slopes_power_of_2(n_heads)

            closest_power_of_2 = 2**math.floor(math.log2(n_heads))
            slopes_a = get_slopes_power_of_2(closest_power_of_2)
            slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
            slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
            return slopes_a + slopes_b

        context_position = torch.arange(size, device=device)[:, None]
        memory_position = torch.arange(size, device=device)[None, :]
        relative_position = torch.abs(memory_position - context_position)
        # (size, size) -> (n_heads, size, size)
        relative_position = relative_position.unsqueeze(0).expand(
            n_heads, -1, -1)
        slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
        alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
        # (n_heads, size, size) -> (1, n_heads, size, size)
        alibi = alibi.unsqueeze(0)
        assert alibi.shape == torch.Size([1, n_heads, size, size])

        self._current_alibi_size = size
        self.alibi = alibi

    def _build_attention_bias(self, attention_mask: torch.Tensor, seqlen: int,
                              device: torch.device,
                              dtype: torch.dtype) -> torch.Tensor:
        """Combine the padding mask and ALiBi into one additive attention bias.

        Grows the cached ALiBi tensor (with a warning) when ``seqlen`` exceeds
        its size, or moves it to ``device`` when needed.

        Args:
            attention_mask: (batch, seqlen); 1 for real tokens, 0 for padding.
            seqlen: padded sequence length of the batch.
            device: device the bias must live on.
            dtype: dtype of the returned bias. It should match the hidden
                states so that half-precision models do not mix fp32 and
                fp16/bf16 inside the attention matmuls.

        Returns:
            Tensor of shape (batch, heads, seqlen, seqlen). Padded key
            positions are offset by -10000.
        """
        if self._current_alibi_size < seqlen:
            warnings.warn(
                f'Increasing alibi size from {self._current_alibi_size} to {seqlen}'
            )
            self.rebuild_alibi_tensor(size=seqlen, device=device)
        elif self.alibi.device != device:
            self.alibi = self.alibi.to(device)

        # (batch, seqlen) -> (batch, 1, 1, seqlen); 0 for tokens, -10000 for pads.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=torch.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
        attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
        # The sum is formed in fp32 and then cast to the requested dtype.
        return (attn_bias + alibi_bias).to(dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_all_encoded_layers: Optional[bool] = True,
        subset_mask: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        """Run the layer stack on unpadded tokens.

        Args:
            hidden_states: padded embeddings of shape (batch, seqlen, dim).
            attention_mask: (batch, seqlen); 1 for real tokens, 0 for padding.
            output_all_encoded_layers: when True, return the output of every
                layer. When False, return only the final output.
            subset_mask: optional boolean (batch, seqlen) mask. When given,
                the final layer only computes outputs for the selected
                (non-padding) tokens.

        Returns:
            A list of layer outputs. Intermediate layer outputs are re-padded
            to (batch, seqlen, dim). The last element is the final output:
            padded (batch, seqlen, dim) when ``subset_mask`` is None, otherwise
            the (n_selected, dim) rows chosen by ``subset_mask``.
        """
        attention_mask_bool = attention_mask.bool()
        batch, seqlen = hidden_states.shape[:2]
        alibi_attn_mask = self._build_attention_bias(attention_mask, seqlen,
                                                     hidden_states.device,
                                                     hidden_states.dtype)

        # Drop padding tokens: (batch, seqlen, dim) -> (total_nnz, dim).
        hidden_states, indices, cu_seqlens, _ = unpad_input(
            hidden_states, attention_mask_bool)

        all_encoder_layers = []
        if subset_mask is None:
            for layer_module in self.layer:
                hidden_states = layer_module(hidden_states,
                                             cu_seqlens,
                                             seqlen,
                                             None,
                                             indices,
                                             attn_mask=attention_mask,
                                             bias=alibi_attn_mask)
                if output_all_encoded_layers:
                    all_encoder_layers.append(
                        pad_input(hidden_states, indices, batch, seqlen))
            if not output_all_encoded_layers:
                all_encoder_layers.append(
                    pad_input(hidden_states, indices, batch, seqlen))
        else:
            *leading_layers, final_layer = self.layer
            for layer_module in leading_layers:
                hidden_states = layer_module(hidden_states,
                                             cu_seqlens,
                                             seqlen,
                                             None,
                                             indices,
                                             attn_mask=attention_mask,
                                             bias=alibi_attn_mask)
                if output_all_encoded_layers:
                    all_encoder_layers.append(
                        pad_input(hidden_states, indices, batch, seqlen))
            # Positions (in unpadded order) of the tokens we need at the end.
            subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                       as_tuple=False).flatten()
            hidden_states = final_layer(hidden_states,
                                        cu_seqlens,
                                        seqlen,
                                        subset_idx=subset_idx,
                                        indices=indices,
                                        attn_mask=attention_mask,
                                        bias=alibi_attn_mask)
            # The final output is always the last element, so callers can
            # rely on ``[-1]`` whatever ``output_all_encoded_layers`` is.
            all_encoder_layers.append(hidden_states)

        return all_encoder_layers
|
|
|
|
|
|
class BertPooler(nn.Module):
    """Dense + tanh projection of the first token's hidden state."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self,
                hidden_states: torch.Tensor,
                pool: Optional[bool] = True) -> torch.Tensor:
        """Pool ``hidden_states``.

        With ``pool`` true, the first token along axis 1 is selected
        (``hidden_states[:, 0]``). Otherwise the input is assumed to already
        hold the rows to project.
        """
        if pool:
            hidden_states = hidden_states[:, 0]
        return self.activation(self.dense(hidden_states))
|
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the MLM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        # A string names an entry in ACT2FN; anything else is used as a callable.
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        # NOTE: eps is fixed at 1e-12 here rather than read from the config.
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
|
|
class BertModel(BertPreTrainedModel):
    """Overall Mosaic BERT model: embeddings, unpadded encoder, optional pooler.

    Args:
        config: a BertConfig class instance with the configuration to build a
            new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `token_type_ids`: an optional torch.LongTensor of shape
            [batch_size, sequence_length] with token type indices in [0, 1].
            Type 0 corresponds to a `sentence A` and type 1 to a `sentence B`
            token (see the BERT paper). Defaults to all zeros.
        `attention_mask`: an optional torch.LongTensor of shape
            [batch_size, sequence_length] with values in [0, 1] marking real
            tokens (1) versus padding (0). Defaults to all ones.
        `output_all_encoded_layers`: whether to return every encoder layer's
            output or only the final one. Default: `False`.
        `masked_tokens_mask`: optional boolean [batch_size, sequence_length]
            mask. When given, only these tokens (plus the first token of each
            sequence, for pooling) are computed in the final layer.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers`:
            - `True`: the list of outputs returned by the encoder.
            - `False`: only the final sequence output, of shape
              [batch_size, sequence_length, hidden_size] when
              `masked_tokens_mask` is None, otherwise one row per masked token.
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size]
            produced by the pooler from the first (`CLS`) token, or None when
            the model has no pooling layer.

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    model = BertModel(config=config)
    encoder_output, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_all_encoded_layers: Optional[bool] = False,
        masked_tokens_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
        """Encode ``input_ids``. Extra keyword arguments are accepted and ignored."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        embedding_output = self.embeddings(input_ids, token_type_ids,
                                           position_ids)

        if masked_tokens_mask is None:
            subset_mask = None
        else:
            # Keep the masked tokens plus column 0 (the token the pooler reads).
            first_col_mask = torch.zeros_like(masked_tokens_mask)
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            subset_mask=subset_mask)
        final_layer = encoder_outputs[-1]

        if masked_tokens_mask is None:
            sequence_output = final_layer
            if self.pooler is None:
                pooled_output = None
            else:
                pooled_output = self.pooler(sequence_output)
        else:
            # final_layer holds only the subset rows, in unpadded token
            # order; select the masked-token rows and the first-column rows.
            is_token = attention_mask.bool()
            subset_idx = subset_mask[is_token]
            sequence_output = final_layer[masked_tokens_mask[is_token][subset_idx]]
            if self.pooler is None:
                pooled_output = None
            else:
                pool_input = final_layer[first_col_mask[is_token][subset_idx]]
                pooled_output = self.pooler(pool_input, pool=False)

        if not output_all_encoded_layers:
            encoder_outputs = sequence_output

        return encoder_outputs, pooled_output
|
|
|
|
|
|
|
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """Transform hidden states and decode them into vocabulary logits.

    The decoder weight is tied to the supplied embedding matrix, whose shape
    is interpreted as (vocab_size, hidden_size).
    """

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        out_features = bert_model_embedding_weights.size(0)
        in_features = bert_model_embedding_weights.size(1)
        self.decoder = nn.Linear(in_features, out_features)
        # Share the parameter object with the word embeddings (weight tying).
        self.decoder.weight = bert_model_embedding_weights

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for ``hidden_states``."""
        return self.decoder(self.transform(hidden_states))
|
|
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Masked-LM head wrapping a single :class:`BertLMPredictionHead`."""

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.predictions = BertLMPredictionHead(
            config, bert_model_embedding_weights)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for ``sequence_output``."""
        return self.predictions(sequence_output)
|
|
|
|
|
|
class BertOnlyNSPHead(nn.Module):
    """Two-way linear classifier applied to the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        """Return two logits per row of ``pooled_output``."""
        return self.seq_relationship(pooled_output)
|
|
|
|
|
|
|
|
class BertForMaskedLM(BertPreTrainedModel):
    """Mosaic BERT with a masked-language-modeling head.

    The MLM decoder weight is tied to the word-embedding matrix.
    """

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            warnings.warn(
                'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
                'bi-directional self-attention.')

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config,
                                   self.bert.embeddings.word_embeddings.weight)

        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Compute MLM logits and, when ``labels`` is given, the MLM loss.

        Args:
            input_ids: (batch, seqlen) token ids. Required.
            labels: optional (batch, seqlen) target ids. Only positions with
                ``labels > 0`` are predicted and scored; other positions
                (e.g. -100 or 0) are ignored.
            Other arguments are forwarded to :class:`BertModel`, which ignores
            those it does not use.

        Returns:
            ``MaskedLMOutput`` (or a tuple when ``return_dict`` is false) with
            logits of shape (batch, seqlen, vocab). With ``labels``, only the
            rows of the masked positions are computed and scattered back with
            ``index_put_first_axis``.

        Raises:
            ValueError: if both or neither of ``input_ids``/``inputs_embeds``
                are given, or if only ``inputs_embeds`` is given.
        """
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError('Must specify either input_ids or input_embeds!')
        if input_ids is None:
            # BertModel.forward derives its masks and embeddings from
            # input_ids and drops inputs_embeds, so fail early and clearly
            # instead of crashing deep inside.
            raise ValueError(
                'BertForMaskedLM requires input_ids; inputs_embeds is not '
                'supported by the underlying BertModel.')

        if labels is None:
            masked_tokens_mask = None
        else:
            masked_tokens_mask = labels > 0

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            masked_tokens_mask=masked_tokens_mask,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            # With masked_tokens_mask set, prediction_scores has one row per
            # masked token, in the same row-major order as labels.flatten().
            loss_fct = nn.CrossEntropyLoss()
            masked_token_idx = torch.nonzero(labels.flatten() > 0,
                                             as_tuple=False).flatten()
            loss = loss_fct(prediction_scores,
                            labels.flatten()[masked_token_idx])

            # Scatter the per-masked-token rows back to (batch, seqlen, vocab).
            batch, seqlen = input_ids.shape[:2]
            prediction_scores = rearrange(index_put_first_axis(
                prediction_scores, masked_token_idx, batch * seqlen),
                                          '(b s) d -> b s d',
                                          b=batch)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs[0],
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
                                      attention_mask: torch.Tensor,
                                      **model_kwargs):
        """Append one PAD token (masked out) to every sequence for generation.

        Raises:
            ValueError: if ``config.pad_token_id`` is not set.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        if self.config.pad_token_id is None:
            raise ValueError('The PAD token should be defined for generation')

        attention_mask = torch.cat([
            attention_mask,
            attention_mask.new_zeros((attention_mask.shape[0], 1))
        ],
                                   dim=-1)
        dummy_token = torch.full((effective_batch_size, 1),
                                 self.config.pad_token_id,
                                 dtype=torch.long,
                                 device=input_ids.device)
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
|
|
|
|
|
|
|
class BertForSequenceClassification(BertPreTrainedModel):
    """Bert Model transformer with a sequence classification/regression head.

    The head is dropout followed by a linear layer on the pooled output,
    e.g. for GLUE tasks.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        if config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def _compute_loss(self, logits: torch.Tensor,
                      labels: torch.Tensor) -> Optional[torch.Tensor]:
        """Infer ``config.problem_type`` if unset (persisting it), then score.

        Returns None when ``problem_type`` is not one of the known values.
        """
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = 'regression'
            elif self.num_labels > 1 and labels.dtype in (torch.long,
                                                          torch.int):
                self.config.problem_type = 'single_label_classification'
            else:
                self.config.problem_type = 'multi_label_classification'

        problem = self.config.problem_type
        if problem == 'regression':
            mse = nn.MSELoss()
            if self.num_labels == 1:
                return mse(logits.squeeze(), labels.squeeze())
            return mse(logits, labels)
        if problem == 'single_label_classification':
            cross_entropy = nn.CrossEntropyLoss()
            return cross_entropy(logits.view(-1, self.num_labels),
                                 labels.view(-1))
        if problem == 'multi_label_classification':
            return nn.BCEWithLogitsLoss()(logits, labels)
        return None

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Classify (or regress on) each sequence from its pooled output."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled = self.dropout(outputs[1])
        logits = self.classifier(pooled)

        loss = None if labels is None else self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output if loss is None else (loss,) + output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs[0],
            attentions=None,
        )
|
|
|
|
|