robinzixuan committed
Commit e74bd5e · verified · 1 Parent(s): 6b3a2b8

Upload 2 files

Files changed (2)
  1. configuration_opt.py +143 -0
  2. modeling_opt.py +1733 -0
configuration_opt.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Metaseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """OPT model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class OPTConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of an [`OPTModel`]. It is used to instantiate an OPT model
27
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the OPT
29
+ [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 50272):
37
+ Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`OPTModel`]
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimensionality of the layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of decoder layers.
43
+ ffn_dim (`int`, *optional*, defaults to 3072):
44
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
45
+ num_attention_heads (`int`, *optional*, defaults to 12):
46
+ Number of attention heads for each attention layer in the Transformer decoder.
47
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
50
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
51
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
52
+ just in case (e.g., 512 or 1024 or 2048).
53
+ do_layer_norm_before (`bool`, *optional*, defaults to `True`):
54
+ Whether to perform layer normalization before the attention block.
55
+ word_embed_proj_dim (`int`, *optional*):
56
+ `word_embed_proj_dim` can be set to down-project word embeddings, *e.g.* `opt-350m`. Defaults to
57
+ `hidden_size`.
58
+ dropout (`float`, *optional*, defaults to 0.1):
59
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
60
+ attention_dropout (`float`, *optional*, defaults to 0.0):
61
+ The dropout ratio for the attention probabilities.
62
+ layerdrop (`float`, *optional*, defaults to 0.0):
63
+ The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
64
+ details.
65
+ init_std (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models).
69
+ enable_bias (`bool`, *optional*, defaults to `True`):
70
+ Whether or not the linear layers in the attention blocks should use the bias term.
71
+ layer_norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
72
+ Whether or not the layer norms should have learnable parameters.
73
+
74
+ Example:
75
+
76
+ ```python
77
+ >>> from transformers import OPTConfig, OPTModel
78
+
79
+ >>> # Initializing an OPT facebook/opt-large style configuration
80
+ >>> configuration = OPTConfig()
81
+
82
+ >>> # Initializing a model (with random weights) from the facebook/opt-large style configuration
83
+ >>> model = OPTModel(configuration)
84
+
85
+ >>> # Accessing the model configuration
86
+ >>> configuration = model.config
87
+ ```"""
88
+
89
+ model_type = "opt"
90
+ keys_to_ignore_at_inference = ["past_key_values"]
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_size=50272,
95
+ hidden_size=768,
96
+ num_hidden_layers=12,
97
+ ffn_dim=3072,
98
+ max_position_embeddings=2048,
99
+ do_layer_norm_before=True,
100
+ _remove_final_layer_norm=False,
101
+ word_embed_proj_dim=None,
102
+ dropout=0.1,
103
+ attention_dropout=0.0,
104
+ num_attention_heads=12,
105
+ activation_function="relu",
106
+ layerdrop=0.0,
107
+ init_std=0.02,
108
+ use_cache=True,
109
+ pad_token_id=1,
110
+ bos_token_id=2,
111
+ eos_token_id=2,
112
+ enable_bias=True,
113
+ layer_norm_elementwise_affine=True,
114
+ **kwargs,
115
+ ):
116
+ super().__init__(
117
+ pad_token_id=pad_token_id,
118
+ bos_token_id=bos_token_id,
119
+ eos_token_id=eos_token_id,
120
+ **kwargs,
121
+ )
122
+ self.vocab_size = vocab_size
123
+ self.max_position_embeddings = max_position_embeddings
124
+ self.num_attention_heads = num_attention_heads
125
+ self.word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size
126
+ self.ffn_dim = ffn_dim
127
+ self.hidden_size = hidden_size
128
+ self.num_hidden_layers = num_hidden_layers
129
+ self.dropout = dropout
130
+ self.attention_dropout = attention_dropout
131
+ self.activation_function = activation_function
132
+ self.init_std = init_std
133
+ self.layerdrop = layerdrop
134
+ self.use_cache = use_cache
135
+ self.do_layer_norm_before = do_layer_norm_before
136
+ # We keep these variables at `True` for backward compatibility.
137
+ self.enable_bias = enable_bias
138
+ self.layer_norm_elementwise_affine = layer_norm_elementwise_affine
139
+
140
+ # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
141
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
142
+ # see https://github.com/facebookresearch/metaseq/pull/164
143
+ self._remove_final_layer_norm = _remove_final_layer_norm
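
Since `OPTConfig` here ships with the repository rather than with the `transformers` package, the class is imported directly from the uploaded file. A minimal usage sketch (illustrative only; it assumes the two uploaded files sit on the import path, and the override values are arbitrary examples, not recommendations):

```python
# Minimal sketch: instantiate the repository's OPTConfig and override a few fields.
from configuration_opt import OPTConfig

config = OPTConfig(
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    ffn_dim=3072,
    word_embed_proj_dim=None,   # falls back to hidden_size (768)
)
print(config.word_embed_proj_dim)   # 768
print(config.do_layer_norm_before)  # True
```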
modeling_opt.py ADDED
@@ -0,0 +1,1733 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch OPT model."""
16
+ import numpy as np
17
+ from typing import List, Optional, Tuple, Union
18
+ from functools import partial
19
+ import torch
20
+
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPast,
30
+ CausalLMOutputWithPast,
31
+ QuestionAnsweringModelOutput,
32
+ SequenceClassifierOutputWithPast,
33
+ )
34
+ from enum import Flag, auto
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (
37
+ add_code_sample_docstrings,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+
41
+ is_flash_attn_2_available,
42
+ is_flash_attn_greater_or_equal_2_10,
43
+ logging,
44
+ replace_return_docstrings,
45
+
46
+ )
47
+ from .configuration_opt import OPTConfig
48
+
49
+
50
+ def logit(p, eps=1e-16):
51
+ p = np.clip(p, eps, 1 - eps)
52
+ return -np.log(1 / p - 1)
53
+
54
+
55
+ class BaseEnumOptions(Flag):
56
+ def __str__(self):
57
+ return self.name
58
+
59
+ @classmethod
60
+ def list_names(cls):
61
+ return [m.name for m in cls]
62
+
63
+
64
+ class AttentionGateType(BaseEnumOptions):
65
+ none = 0
66
+ unconditional_per_head = 1
67
+ conditional_per_head = 2
68
+ conditional_per_token = 3
69
+
70
+
71
+ if is_flash_attn_2_available():
72
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
73
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
74
+
75
+
76
+ logger = logging.get_logger(__name__)
77
+
78
+ _CHECKPOINT_FOR_DOC = "facebook/opt-350m"
79
+ _CONFIG_FOR_DOC = "OPTConfig"
80
+
81
+ # Base model docstring
82
+ _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
83
+
84
+ # SequenceClassification docstring
85
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
86
+ _SEQ_CLASS_EXPECTED_LOSS = 1.71
87
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
88
+
89
+
90
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
91
+ def _get_unpad_data(attention_mask):
92
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
93
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
94
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
95
+ cu_seqlens = F.pad(torch.cumsum(
96
+ seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
97
+ return (
98
+ indices,
99
+ cu_seqlens,
100
+ max_seqlen_in_batch,
101
+ )
102
+
103
+
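
For orientation, a small worked example (illustrative only, not part of the commit) of the unpadding metadata that `_get_unpad_data` computes from a padding mask; the flash-attention varlen path further below consumes exactly these token indices and cumulative sequence lengths:

```python
import torch
import torch.nn.functional as F

# Two sequences of real lengths 3 and 2, right-padded to length 4.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(indices)     # tensor([0, 1, 2, 4, 5]) -> flat positions of real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -> per-sequence offsets
```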
104
+ class OPTLearnedPositionalEmbedding(nn.Embedding):
105
+ """
106
+ This module learns positional embeddings up to a fixed maximum size.
107
+ """
108
+
109
+ def __init__(self, num_embeddings: int, embedding_dim: int):
110
+ # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
111
+ # and adjust num_embeddings appropriately. Other models don't have this hack
112
+ self.offset = 2
113
+ super().__init__(num_embeddings + self.offset, embedding_dim)
114
+
115
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
116
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
117
+ attention_mask = attention_mask.long()
118
+
119
+ # create positions depending on attention_mask
120
+ positions = (torch.cumsum(attention_mask, dim=1).type_as(
121
+ attention_mask) * attention_mask).long() - 1
122
+
123
+ # cut positions if `past_key_values_length` is > 0
124
+ positions = positions[:, past_key_values_length:]
125
+
126
+ return super().forward(positions + self.offset)
127
+
128
+
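
A tiny worked example (illustrative, not part of the file) of the position ids `OPTLearnedPositionalEmbedding` derives from a left-padded attention mask; padding slots resolve to -1 before the fixed offset of 2 is added:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding tokens
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask)
             * attention_mask).long() - 1
print(positions)      # tensor([[-1, -1, 0, 1, 2]])
print(positions + 2)  # tensor([[1, 1, 2, 3, 4]]) -> indices actually embedded
```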
129
+ def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
130
+ """
131
+ $\text{softmax}_n(x_i) = \exp(x_i) / (n + \sum_j \exp(x_j))$
132
+ Note: softmax_n, with fixed input, is _not_ shift-symmetric when n != 0
133
+ """
134
+ # compute the maxes along the last dimension
135
+ input_maxes = input.max(dim=dim, keepdim=True).values
136
+ # shift the input to prevent overflow (and underflow in the denominator)
137
+ shifted_inputs = torch.subtract(input, input_maxes)
138
+ # compute the numerator and softmax_0 denominator using the shifted input
139
+ numerator = torch.exp(shifted_inputs)
140
+ original_denominator = numerator.sum(dim=dim, keepdim=True)
141
+ # we need to shift the zeros in the same way we shifted the inputs
142
+ shifted_zeros = torch.multiply(input_maxes, -1)
143
+ # and then add this contribution to the denominator
144
+ denominator = torch.add(original_denominator,
145
+ torch.multiply(torch.exp(shifted_zeros), n))
146
+ return torch.divide(numerator, denominator)
147
+
148
+
149
+ def softmax_1(input: torch.Tensor, dim=-1, dtype=torch.float32) -> torch.Tensor:
150
+ """
151
+ $\text{softmax}_1(x_i) = \exp(x_i) / (1 + \sum_j \exp(x_j))$
152
+ """
153
+ output = softmax_n_shifted_zeros(input, 1, dim=dim)
154
+ return output if dtype is None else output.type(dtype=dtype)
155
+
156
+
157
+ def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
158
+ sm_out = torch.nn.functional.softmax(data, dim=dim, **kw)
159
+ stretched_out = sm_out * (eta - gamma) + gamma
160
+ return torch.clip(stretched_out, 0, 1)
161
+
162
+
163
+ def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
164
+ sm_out = softmax_1(data, dim=dim, **kw)
165
+ stretched_out = sm_out * (eta - gamma) + gamma
166
+ return torch.clip(stretched_out, 0, 1)
167
+
168
+
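
A quick numerical illustration (not part of the model code; it assumes the helpers defined above are in scope) of why these variants exist: plain softmax always sums to 1, whereas softmax_1 lets a head assign almost no total attention, and the clipped variants can reach exact 0/1 probabilities:

```python
import torch

logits = torch.tensor([[-8.0, -9.0, -10.0]])
print(torch.softmax(logits, dim=-1).sum())   # ~1.0 by construction
print(softmax_1(logits, dim=-1).sum())       # ~5e-4, i.e. the head "attends to nothing"
print(clipped_softmax(torch.tensor([[4.0, -4.0]]), dim=-1))  # tensor([[1., 0.]])
```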
169
+ class OPTAttention(nn.Module):
170
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
171
+
172
+ def __init__(
173
+ self,
174
+ config: OPTConfig,
175
+ dropout: float = 0.0,
176
+ is_decoder: bool = False,
177
+ bias: bool = True,
178
+ # new
179
+ softmax_fn=nn.functional.softmax,
180
+ alpha=None,
181
+ max_seq_length=512,
182
+ ssm_eps=None,
183
+ tau=None,
184
+ skip_attn=False,
185
+ attn_gate_type=AttentionGateType.conditional_per_token,
186
+ attn_gate_init=0.25,
187
+ attn_gate_mlp=False,
188
+ attn_gate_mlp2=False,
189
+ attn_gate_linear_all_features=False,
190
+ fine_tuning=False,
191
+ attn_softmax='Vanilla',
192
+ ):
193
+ super().__init__()
194
+ self.embed_dim = config.hidden_size
195
+ self.num_heads = config.num_attention_heads
196
+ self.dropout = config.attention_dropout
197
+ self.enable_bias = config.enable_bias
198
+ self.head_dim = self.embed_dim // self.num_heads
199
+ self.is_causal = True
200
+
201
+ if (self.head_dim * self.num_heads) != self.embed_dim:
202
+ raise ValueError(
203
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
204
+ f" and `num_heads`: {self.num_heads})."
205
+ )
206
+ self.scaling = self.head_dim**-0.5
207
+ self.is_decoder = is_decoder
208
+
209
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
210
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
211
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
212
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
213
+
214
+ # YB: capture the input and output of the softmax
215
+ self.attn_scores = nn.Identity() # before attention mask
216
+ self.attn_probs_before_dropout = nn.Identity()
217
+ self.attn_probs_after_dropout = nn.Identity()
218
+
219
+ self.alpha = alpha
220
+ self.max_seq_length = max_seq_length
221
+ self.ssm_eps = ssm_eps
222
+ self.tau = tau
223
+ self.attn_softmax = attn_softmax
224
+
225
+ # define softmax function
226
+ if self.alpha is not None:
227
+ assert self.max_seq_length is not None
228
+ gamma = -self.alpha / self.max_seq_length
229
+ if self.attn_softmax is "softmax1":
230
+ print("Using clipped Softmax_1!")
231
+ self.softmax_fn = partial(
232
+ clipped_softmax1, gamma=gamma, eta=1.0)
233
+ else:
234
+ self.softmax_fn = partial(
235
+ clipped_softmax, gamma=gamma, eta=1.0)
236
+ else:
237
+ self.softmax_fn = softmax_fn
238
+
239
+ self.skip_attn = skip_attn
240
+
241
+ # attention gating
242
+ self.last_gate_avg_prob = None
243
+ self.last_gate_all_probs = None
244
+
245
+ self.attn_gate_type = attn_gate_type
246
+ self.attn_gate_init = attn_gate_init
247
+ self.attn_gate_mlp = attn_gate_mlp
248
+ self.attn_gate_mlp2 = attn_gate_mlp2
249
+ self.attn_gate_linear_all_features = attn_gate_linear_all_features
250
+
251
+ self.alpha = None
252
+ self.ssm_eps = ssm_eps
253
+ self.gate_fn = torch.sigmoid
254
+ self.pooling_fn = partial(torch.mean, dim=1, keepdims=True)
255
+
256
+ self.fine_tuning = fine_tuning
257
+
258
+ # gate scaling factor
259
+ self.gate_scaling_factor = 1.0
260
+ if self.fine_tuning and self.attn_gate_init is not None:
261
+ self.gate_scaling_factor = 1.0 / self.attn_gate_init
262
+
263
+ # define gate
264
+ if self.attn_gate_type == AttentionGateType.unconditional_per_head:
265
+ init_alpha = torch.zeros(size=(self.num_heads,))
266
+ self.alpha = nn.Parameter(init_alpha, requires_grad=True)
267
+
268
+ elif self.attn_gate_type in (
269
+ AttentionGateType.conditional_per_head,
270
+ AttentionGateType.conditional_per_token,
271
+ ):
272
+ if self.attn_gate_linear_all_features:
273
+ self.alpha = nn.Linear(
274
+ self.embed_dim, self.num_heads, bias=True)
275
+
276
+ else: # separate predictors for each head
277
+ module_list = []
278
+ for _ in range(self.num_heads):
279
+ if self.attn_gate_mlp:
280
+ fc = nn.Sequential(
281
+ nn.Linear(self.head_dim,
282
+ self.head_dim // 4, bias=True),
283
+ nn.ReLU(),
284
+ nn.Linear(self.head_dim // 4, 1, bias=True),
285
+ )
286
+ elif self.attn_gate_mlp2:
287
+ fc = nn.Sequential(
288
+ nn.Linear(self.head_dim, self.head_dim, bias=True),
289
+ nn.ReLU(),
290
+ nn.Linear(self.head_dim, 1, bias=True),
291
+ )
292
+ else:
293
+ fc = nn.Linear(self.head_dim, 1, bias=True)
294
+
295
+ if self.attn_gate_init is not None:
296
+ init_bias = logit(self.attn_gate_init)
297
+ torch.nn.init.constant_(fc.bias, init_bias)
298
+
299
+ if self.fine_tuning:
300
+ # init to very small values
301
+ torch.nn.init.normal_(
302
+ fc.weight, mean=0.0, std=0.001)
303
+
304
+ module_list.append(fc)
305
+ self.alpha = nn.ModuleList(module_list)
306
+
307
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
308
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
309
+
310
+ def forward(
311
+ self,
312
+ hidden_states: torch.Tensor,
313
+ key_value_states: Optional[torch.Tensor] = None,
314
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ layer_head_mask: Optional[torch.Tensor] = None,
317
+ output_attentions: bool = False,
318
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
319
+ """Input shape: Batch x Time x Channel"""
320
+
321
+ # if key_value_states are provided this layer is used as a cross-attention layer
322
+ # for the decoder
323
+ is_cross_attention = key_value_states is not None
324
+
325
+ bsz, tgt_len, _ = hidden_states.size()
326
+
327
+ # get query proj
328
+ query_states = self.q_proj(hidden_states) * self.scaling
329
+ # get key, value proj
330
+ if is_cross_attention and past_key_value is not None:
331
+ # reuse k,v, cross_attentions
332
+ key_states = past_key_value[0]
333
+ value_states = past_key_value[1]
334
+ elif is_cross_attention:
335
+ # cross_attentions
336
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
337
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
338
+ elif past_key_value is not None:
339
+ # reuse k, v, self_attention
340
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
341
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
342
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
343
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
344
+ else:
345
+ # self_attention
346
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
347
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
348
+
349
+ if self.is_decoder:
350
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
351
+ # Further calls to cross_attention layer can then reuse all cross-attention
352
+ # key/value_states (first "if" case)
353
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
354
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
355
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
356
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
357
+ past_key_value = (key_states, value_states)
358
+
359
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
360
+ query_states = self._shape(
361
+ query_states, tgt_len, bsz).view(*proj_shape)
362
+ key_states = key_states.view(*proj_shape)
363
+ value_states = value_states.view(*proj_shape)
364
+
365
+ src_len = key_states.size(1)
366
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
367
+
368
+ # YB: for logging softmax input
369
+ attn_weights = self.attn_scores(attn_weights)
370
+
371
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
372
+ raise ValueError(
373
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
374
+ f" {attn_weights.size()}"
375
+ )
376
+
377
+ if attention_mask is not None:
378
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
379
+ raise ValueError(
380
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
381
+ )
382
+ attn_weights = attn_weights.view(
383
+ bsz, self.num_heads, tgt_len, src_len) + attention_mask
384
+ attn_weights = torch.max(
385
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
386
+ )
387
+ attn_weights = attn_weights.view(
388
+ bsz * self.num_heads, tgt_len, src_len)
389
+
390
+ # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
391
+ if attn_weights.dtype == torch.float16:
392
+ attn_weights = self.softmax_fn(attn_weights, dim=-1, dtype=torch.float32).to(
393
+ torch.float16
394
+ )
395
+ else:
396
+ attn_weights = self.softmax_fn(attn_weights, dim=-1)
397
+
398
+ if layer_head_mask is not None:
399
+ if layer_head_mask.size() != (self.num_heads,):
400
+ raise ValueError(
401
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
402
+ f" {layer_head_mask.size()}"
403
+ )
404
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
405
+ bsz, self.num_heads, tgt_len, src_len
406
+ )
407
+ attn_weights = attn_weights.view(
408
+ bsz * self.num_heads, tgt_len, src_len)
409
+
410
+ if output_attentions:
411
+ # this operation is a bit awkward, but it's required to
412
+ # make sure that attn_weights keeps its gradient.
413
+ # In order to do so, attn_weights have to be reshaped
414
+ # twice and have to be reused in the following
415
+ attn_weights_reshaped = attn_weights.view(
416
+ bsz, self.num_heads, tgt_len, src_len)
417
+ attn_weights = attn_weights_reshaped.view(
418
+ bsz * self.num_heads, tgt_len, src_len)
419
+ else:
420
+ attn_weights_reshaped = None
421
+
422
+ # YB: for logging softmax output
423
+ attn_weights = self.attn_probs_before_dropout(attn_weights)
424
+
425
+ attn_probs = nn.functional.dropout(
426
+ attn_weights, p=self.dropout, training=self.training)
427
+
428
+ # YB: for logging softmax output
429
+ attn_probs = self.attn_probs_after_dropout(attn_probs)
430
+
431
+ attn_output = torch.bmm(attn_probs, value_states)
432
+
433
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
434
+ raise ValueError(
435
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
436
+ f" {attn_output.size()}"
437
+ )
438
+
439
+ attn_output = attn_output.view(
440
+ bsz, self.num_heads, tgt_len, self.head_dim)
441
+ # attn_output - (B, H, T, d_head)
442
+
443
+ #
444
+ # *** Gating ***
445
+ if self.attn_gate_type == AttentionGateType.unconditional_per_head:
446
+ gate = self.gate_fn(self.alpha) # (H,)
447
+ attn_output *= gate.view(-1, 1, 1) # (B, H, T, d_head)
448
+
449
+ self.last_gate_avg_prob = gate.view(-1)
450
+
451
+ elif self.attn_gate_type in (
452
+ AttentionGateType.conditional_per_head,
453
+ AttentionGateType.conditional_per_token,
454
+ ):
455
+ x = hidden_states # (B, T, d_model)
456
+
457
+ if self.attn_gate_linear_all_features: # assume per_token
458
+ alpha = self.alpha(x) # (B, T, H)
459
+ gate = self.gate_fn(alpha)
460
+ gate = gate.permute(0, 2, 1).contiguous() # (B, H, T)
461
+ gate = gate.unsqueeze(3) # (B, H, T, 1)
462
+
463
+ else:
464
+ # x = self.transpose_for_scores(x) # (B, H, T, d_head)
465
+ x = self._shape(x, -1, bsz) # (B, H, T, d_head)
466
+
467
+ alpha = []
468
+ for head_idx in range(self.num_heads):
469
+ x_head = x[:, head_idx, ...] # (B, T, d_head)
470
+ fc_head = self.alpha[head_idx]
471
+ alpha_head = fc_head(x_head) # (B, T, 1)
472
+ if self.attn_gate_type == AttentionGateType.conditional_per_head:
473
+ alpha_head = self.pooling_fn(alpha_head) # (B, 1, 1)
474
+ alpha.append(alpha_head)
475
+ alpha = torch.stack(alpha, dim=1) # (B, H, *, 1)
476
+ gate = self.gate_fn(alpha)
477
+
478
+ attn_output *= gate * self.gate_scaling_factor
479
+
480
+ self.last_gate_all_probs = gate # all gates to see the distributions
481
+ avg_gate = gate.mean(dim=0)
482
+ self.last_gate_avg_prob = avg_gate.view(
483
+ self.num_heads, -1).mean(dim=1)
484
+
485
+ #
486
+ # <end elif>
487
+
488
+ attn_output = attn_output.transpose(1, 2)
489
+
490
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
491
+ # partitioned across GPUs when using tensor-parallelism.
492
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
493
+
494
+ attn_output = self.out_proj(attn_output)
495
+
496
+ return attn_output, attn_weights_reshaped, past_key_value
497
+
498
+
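
A conceptual sketch (not the exact module code) of the conditional per-token gating applied at the end of `OPTAttention.forward` above: each head gets a scalar gate in (0, 1) per token, predicted from that head's slice of the hidden states, and the head's attention output is scaled by it. Shapes and the 0.25 gate initialisation follow the defaults used above:

```python
import torch
import torch.nn as nn

B, H, T, d_head = 2, 4, 5, 16
x = torch.randn(B, H, T, d_head)         # per-head view of the hidden states
attn_out = torch.randn(B, H, T, d_head)  # per-head attention output

gates = []
for h in range(H):
    fc = nn.Linear(d_head, 1)                 # one small predictor per head
    nn.init.constant_(fc.bias, -1.0986)       # ~= logit(0.25), the attn_gate_init default
    gates.append(torch.sigmoid(fc(x[:, h])))  # (B, T, 1)
gate = torch.stack(gates, dim=1)              # (B, H, T, 1)
gated_out = attn_out * gate                   # heads can learn to "switch off"
```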
499
+ class OptFlashAttention2(OPTAttention):
500
+ """
501
+ OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stay untouched.
502
+ The only required change would be on the forward pass where it needs to correctly call the public API of flash
503
+ attention and deal with padding tokens in case the input contains any of them.
504
+ """
505
+
506
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
507
+ def __init__(self, *args, **kwargs):
508
+ super().__init__(*args, **kwargs)
509
+
510
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
511
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
512
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
513
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
514
+
515
+ def forward(
516
+ self,
517
+ hidden_states: torch.Tensor,
518
+ key_value_states: Optional[torch.Tensor] = None,
519
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
520
+ attention_mask: Optional[torch.Tensor] = None,
521
+ layer_head_mask: Optional[torch.Tensor] = None,
522
+ output_attentions: bool = False,
523
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
524
+ """Input shape: Batch x Time x Channel"""
525
+
526
+ # if key_value_states are provided this layer is used as a cross-attention layer
527
+ # for the decoder
528
+ is_cross_attention = key_value_states is not None
529
+
530
+ bsz, _, _ = hidden_states.size()
531
+
532
+ # get query proj
533
+ query_states = self.q_proj(hidden_states)
534
+ # get key, value proj
535
+ if is_cross_attention and past_key_value is not None:
536
+ # reuse k,v, cross_attentions
537
+ key_states = past_key_value[0]
538
+ value_states = past_key_value[1]
539
+ elif is_cross_attention:
540
+ # cross_attentions
541
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
542
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
543
+ elif past_key_value is not None:
544
+ # reuse k, v, self_attention
545
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
546
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
547
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
548
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
549
+ else:
550
+ # self_attention
551
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
552
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
553
+
554
+ if self.is_decoder:
555
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
556
+ # Further calls to cross_attention layer can then reuse all cross-attention
557
+ # key/value_states (first "if" case)
558
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
559
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
560
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
561
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
562
+ past_key_value = (key_states, value_states)
563
+
564
+ query_length = query_states.shape[1]
565
+ tgt_len = key_states.shape[-2]
566
+
567
+ # Flash attention requires the input to have the shape
568
+ # batch_size x seq_length x num_heads x head_dim
569
+ query_states = query_states.view(
570
+ bsz, query_length, self.num_heads, self.head_dim)
571
+ key_states = key_states.transpose(1, 2).view(
572
+ bsz, tgt_len, self.num_heads, self.head_dim)
573
+ value_states = value_states.transpose(1, 2).view(
574
+ bsz, tgt_len, self.num_heads, self.head_dim)
575
+
576
+ attn_dropout = self.dropout if self.training else 0.0
577
+
578
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
579
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
580
+ # cast them back to float16 just to be sure everything works as expected.
581
+ input_dtype = query_states.dtype
582
+ if input_dtype == torch.float32:
583
+ if torch.is_autocast_enabled():
584
+ target_dtype = torch.get_autocast_gpu_dtype()
585
+ # Handle the case where the model is quantized
586
+ elif hasattr(self.config, "_pre_quantization_dtype"):
587
+ target_dtype = self.config._pre_quantization_dtype
588
+ else:
589
+ target_dtype = self.q_proj.weight.dtype
590
+
591
+ logger.warning_once(
592
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
593
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
594
+ f" {target_dtype}."
595
+ )
596
+
597
+ query_states = query_states.to(target_dtype)
598
+ key_states = key_states.to(target_dtype)
599
+ value_states = value_states.to(target_dtype)
600
+
601
+ attn_output = self._flash_attention_forward(
602
+ query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout
603
+ )
604
+
605
+ attn_weights_reshaped = attn_output.reshape(
606
+ bsz, query_length, self.num_heads * self.head_dim)
607
+ attn_output = self.out_proj(attn_weights_reshaped)
608
+
609
+ if not output_attentions:
610
+ attn_weights_reshaped = None
611
+
612
+ return attn_output, attn_weights_reshaped, past_key_value
613
+
614
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
615
+ def _flash_attention_forward(
616
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
617
+ ):
618
+ """
619
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
620
+ it first unpads the input, then computes the attention scores and pads the final attention scores.
621
+
622
+ Args:
623
+ query_states (`torch.Tensor`):
624
+ Input query states to be passed to Flash Attention API
625
+ key_states (`torch.Tensor`):
626
+ Input key states to be passed to Flash Attention API
627
+ value_states (`torch.Tensor`):
628
+ Input value states to be passed to Flash Attention API
629
+ attention_mask (`torch.Tensor`):
630
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
631
+ position of padding tokens and 1 for the position of non-padding tokens.
632
+ dropout (`float`):
633
+ Attention dropout
634
+ softmax_scale (`float`, *optional*):
635
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
636
+ """
637
+ if not self._flash_attn_uses_top_left_mask:
638
+ causal = self.is_causal
639
+ else:
640
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
641
+ causal = self.is_causal and query_length != 1
642
+
643
+ # Contains at least one padding token in the sequence
644
+ if attention_mask is not None:
645
+ batch_size = query_states.shape[0]
646
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
647
+ query_states, key_states, value_states, attention_mask, query_length
648
+ )
649
+
650
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
651
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
652
+
653
+ attn_output_unpad = flash_attn_varlen_func(
654
+ query_states,
655
+ key_states,
656
+ value_states,
657
+ cu_seqlens_q=cu_seqlens_q,
658
+ cu_seqlens_k=cu_seqlens_k,
659
+ max_seqlen_q=max_seqlen_in_batch_q,
660
+ max_seqlen_k=max_seqlen_in_batch_k,
661
+ dropout_p=dropout,
662
+ softmax_scale=softmax_scale,
663
+ causal=causal,
664
+ )
665
+
666
+ attn_output = pad_input(
667
+ attn_output_unpad, indices_q, batch_size, query_length)
668
+ else:
669
+ attn_output = flash_attn_func(
670
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
671
+ )
672
+
673
+ return attn_output
674
+
675
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
676
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
677
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
678
+ attention_mask)
679
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
680
+
681
+ key_layer = index_first_axis(
682
+ key_layer.reshape(batch_size * kv_seq_len,
683
+ num_key_value_heads, head_dim), indices_k
684
+ )
685
+ value_layer = index_first_axis(
686
+ value_layer.reshape(batch_size * kv_seq_len,
687
+ num_key_value_heads, head_dim), indices_k
688
+ )
689
+ if query_length == kv_seq_len:
690
+ query_layer = index_first_axis(
691
+ query_layer.reshape(batch_size * kv_seq_len,
692
+ self.num_heads, head_dim), indices_k
693
+ )
694
+ cu_seqlens_q = cu_seqlens_k
695
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
696
+ indices_q = indices_k
697
+ elif query_length == 1:
698
+ max_seqlen_in_batch_q = 1
699
+ cu_seqlens_q = torch.arange(
700
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
701
+ ) # There is a memcpy here, that is very bad.
702
+ indices_q = cu_seqlens_q[:-1]
703
+ query_layer = query_layer.squeeze(1)
704
+ else:
705
+ # The -q_len: slice assumes left padding.
706
+ attention_mask = attention_mask[:, -query_length:]
707
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
708
+ query_layer, attention_mask)
709
+
710
+ return (
711
+ query_layer,
712
+ key_layer,
713
+ value_layer,
714
+ indices_q,
715
+ (cu_seqlens_q, cu_seqlens_k),
716
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
717
+ )
718
+
719
+
720
+ OPT_ATTENTION_CLASSES = {
721
+ "eager": OPTAttention,
722
+ "flash_attention_2": OptFlashAttention2,
723
+ }
724
+
725
+
726
+ class OPTDecoderLayer(nn.Module):
727
+ def __init__(self, config: OPTConfig):
728
+ super().__init__()
729
+ self.embed_dim = config.hidden_size
730
+
731
+ self.self_attn = OPTAttention(
732
+ config=config, is_decoder=True)
733
+
734
+ self.do_layer_norm_before = config.do_layer_norm_before
735
+ self.dropout = config.dropout
736
+ self.activation_fn = ACT2FN[config.activation_function]
737
+
738
+ self.self_attn_layer_norm = nn.LayerNorm(
739
+ self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
740
+ )
741
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim,
742
+ bias=config.enable_bias)
743
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim,
744
+ bias=config.enable_bias)
745
+ self.final_layer_norm = nn.LayerNorm(
746
+ self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
747
+
748
+ def forward(
749
+ self,
750
+ hidden_states: torch.Tensor,
751
+ attention_mask: Optional[torch.Tensor] = None,
752
+ layer_head_mask: Optional[torch.Tensor] = None,
753
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
754
+ output_attentions: Optional[bool] = False,
755
+ use_cache: Optional[bool] = False,
756
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
757
+ """
758
+ Args:
759
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
760
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
761
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
762
+ layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
763
+ `(encoder_attention_heads,)`.
764
+ output_attentions (`bool`, *optional*):
765
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
766
+ returned tensors for more detail.
767
+ use_cache (`bool`, *optional*):
768
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
769
+ (see `past_key_values`).
770
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
771
+ """
772
+
773
+ residual = hidden_states
774
+
775
+ # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
776
+ if self.do_layer_norm_before:
777
+ hidden_states = self.self_attn_layer_norm(hidden_states)
778
+
779
+ # Self Attention
780
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
781
+ hidden_states=hidden_states,
782
+ past_key_value=past_key_value,
783
+ attention_mask=attention_mask,
784
+ layer_head_mask=layer_head_mask,
785
+ output_attentions=output_attentions,
786
+ )
787
+ hidden_states = nn.functional.dropout(
788
+ hidden_states, p=self.dropout, training=self.training)
789
+ hidden_states = residual + hidden_states
790
+
791
+ # 350m applies layer norm AFTER attention
792
+ if not self.do_layer_norm_before:
793
+ hidden_states = self.self_attn_layer_norm(hidden_states)
794
+
795
+ # Fully Connected
796
+ hidden_states_shape = hidden_states.shape
797
+ hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
798
+ residual = hidden_states
799
+
800
+ # 125m, 1.3B, ..., 175B applies layer norm BEFORE the feed-forward block
801
+ if self.do_layer_norm_before:
802
+ hidden_states = self.final_layer_norm(hidden_states)
803
+
804
+ hidden_states = self.fc1(hidden_states)
805
+ hidden_states = self.activation_fn(hidden_states)
806
+
807
+ hidden_states = self.fc2(hidden_states)
808
+ hidden_states = nn.functional.dropout(
809
+ hidden_states, p=self.dropout, training=self.training)
810
+
811
+ hidden_states = (residual + hidden_states).view(hidden_states_shape)
812
+
813
+ # 350m applies layer norm AFTER the feed-forward block
814
+ if not self.do_layer_norm_before:
815
+ hidden_states = self.final_layer_norm(hidden_states)
816
+
817
+ outputs = (hidden_states,)
818
+
819
+ if output_attentions:
820
+ outputs += (self_attn_weights,)
821
+
822
+ if use_cache:
823
+ outputs += (present_key_value,)
824
+
825
+ return outputs
826
+
827
+
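
For readers skimming the layer code above, a schematic (illustration only; the real forward also handles caching, masks, dropout and the flattened FFN reshape) of the two LayerNorm placements selected by `do_layer_norm_before`:

```python
def pre_ln_block(x, ln, sublayer):
    # do_layer_norm_before=True (most OPT sizes): normalize, transform, then add residual
    return x + sublayer(ln(x))

def post_ln_block(x, ln, sublayer):
    # do_layer_norm_before=False (opt-350m): add residual first, then normalize
    return ln(x + sublayer(x))
```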
828
+ OPT_START_DOCSTRING = r"""
829
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
830
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
831
+ etc.)
832
+
833
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
834
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
835
+ and behavior.
836
+
837
+ Parameters:
838
+ config ([`OPTConfig`]):
839
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
840
+ load the weights associated with the model, only the configuration. Check out the
841
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
842
+ """
843
+
844
+
845
+ @add_start_docstrings(
846
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
847
+ OPT_START_DOCSTRING,
848
+ )
849
+ class OPTPreTrainedModel(PreTrainedModel):
850
+ config_class = OPTConfig
851
+ base_model_prefix = "model"
852
+ supports_gradient_checkpointing = True
853
+ _no_split_modules = ["OPTDecoderLayer"]
854
+ _supports_flash_attn_2 = True
855
+
856
+ def _init_weights(self, module):
857
+ std = self.config.init_std
858
+ if isinstance(module, nn.Linear):
859
+ module.weight.data.normal_(mean=0.0, std=std)
860
+ if module.bias is not None:
861
+ module.bias.data.zero_()
862
+ elif isinstance(module, nn.Embedding):
863
+ module.weight.data.normal_(mean=0.0, std=std)
864
+ if module.padding_idx is not None:
865
+ module.weight.data[module.padding_idx].zero_()
866
+
867
+
868
+ OPT_INPUTS_DOCSTRING = r"""
869
+ Args:
870
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
871
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
872
+ it.
873
+
874
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
875
+ [`PreTrainedTokenizer.__call__`] for details.
876
+
877
+ [What are input IDs?](../glossary#input-ids)
878
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
879
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
880
+
881
+ - 1 for tokens that are **not masked**,
882
+ - 0 for tokens that are **masked**.
883
+
884
+ [What are attention masks?](../glossary#attention-mask)
885
+
886
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
887
+ [`PreTrainedTokenizer.__call__`] for details.
888
+
889
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
890
+ `past_key_values`).
891
+
892
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
893
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
894
+ information on the default strategy.
895
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
896
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
897
+
898
+ - 1 indicates the head is **not masked**,
899
+ - 0 indicates the head is **masked**.
900
+
901
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
902
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
903
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
904
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
905
+
906
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
907
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
908
+
909
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
910
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
911
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
912
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
913
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
914
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
915
+ model's internal embedding lookup matrix.
916
+ use_cache (`bool`, *optional*):
917
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
918
+ `past_key_values`).
919
+ output_attentions (`bool`, *optional*):
920
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
921
+ tensors for more detail.
922
+ output_hidden_states (`bool`, *optional*):
923
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
924
+ more detail.
925
+ return_dict (`bool`, *optional*):
926
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
927
+ """
928
+
929
+
930
+ class OPTDecoder(OPTPreTrainedModel):
931
+ """
932
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
933
+
934
+ Args:
935
+ config: OPTConfig
936
+ """
937
+
938
+ def __init__(self, config: OPTConfig):
939
+ super().__init__(config)
940
+ self.dropout = config.dropout
941
+ self.layerdrop = config.layerdrop
942
+ self.padding_idx = config.pad_token_id
943
+ self.max_target_positions = config.max_position_embeddings
944
+ self.vocab_size = config.vocab_size
945
+
946
+ self.embed_tokens = nn.Embedding(
947
+ config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
948
+ self.embed_positions = OPTLearnedPositionalEmbedding(
949
+ config.max_position_embeddings, config.hidden_size)
950
+
951
+ if config.word_embed_proj_dim != config.hidden_size:
952
+ self.project_out = nn.Linear(
953
+ config.hidden_size, config.word_embed_proj_dim, bias=False)
954
+ else:
955
+ self.project_out = None
956
+
957
+ if config.word_embed_proj_dim != config.hidden_size:
958
+ self.project_in = nn.Linear(
959
+ config.word_embed_proj_dim, config.hidden_size, bias=False)
960
+ else:
961
+ self.project_in = None
962
+
963
+ # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
964
+ # with checkpoints that have been fine-tuned before transformers v4.20.1
965
+ # see https://github.com/facebookresearch/metaseq/pull/164
966
+ if config.do_layer_norm_before and not config._remove_final_layer_norm:
967
+ self.final_layer_norm = nn.LayerNorm(
968
+ config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
969
+ )
970
+ else:
971
+ self.final_layer_norm = None
972
+
973
+ self.layers = nn.ModuleList(
974
+ [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
975
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
976
+
977
+ self.gradient_checkpointing = False
978
+ # Initialize weights and apply final processing
979
+ self.post_init()
980
+
981
+ def get_input_embeddings(self):
982
+ return self.embed_tokens
983
+
984
+ def set_input_embeddings(self, value):
985
+ self.embed_tokens = value
986
+
987
+ def forward(
988
+ self,
989
+ input_ids: torch.LongTensor = None,
990
+ attention_mask: Optional[torch.Tensor] = None,
991
+ head_mask: Optional[torch.Tensor] = None,
992
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
993
+ inputs_embeds: Optional[torch.FloatTensor] = None,
994
+ use_cache: Optional[bool] = None,
995
+ output_attentions: Optional[bool] = None,
996
+ output_hidden_states: Optional[bool] = None,
997
+ return_dict: Optional[bool] = None,
998
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
999
+ r"""
1000
+ Args:
1001
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1002
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1003
+ provide it.
1004
+
1005
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1006
+ [`PreTrainedTokenizer.__call__`] for details.
1007
+
1008
+ [What are input IDs?](../glossary#input-ids)
1009
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1010
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1011
+
1012
+ - 1 for tokens that are **not masked**,
1013
+ - 0 for tokens that are **masked**.
1014
+
1015
+ [What are attention masks?](../glossary#attention-mask)
1016
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
1017
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1018
+
1019
+ - 1 indicates the head is **not masked**,
1020
+ - 0 indicates the head is **masked**.
1021
+
1022
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1023
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1024
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1025
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
1026
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1027
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1028
+
1029
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1030
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1031
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1032
+
1033
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1034
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1035
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1036
+ than the model's internal embedding lookup matrix.
1037
+ output_attentions (`bool`, *optional*):
1038
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1039
+ returned tensors for more detail.
1040
+ output_hidden_states (`bool`, *optional*):
1041
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1042
+ for more detail.
1043
+ return_dict (`bool`, *optional*):
1044
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1045
+ """
1046
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1047
+ output_hidden_states = (
1048
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1049
+ )
1050
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1051
+
1052
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1053
+
1054
+ # retrieve input_ids and inputs_embeds
1055
+ if input_ids is not None and inputs_embeds is not None:
1056
+ raise ValueError(
1057
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1058
+ elif input_ids is not None:
1059
+ input_shape = input_ids.size()
1060
+ input_ids = input_ids.view(-1, input_shape[-1])
1061
+ elif inputs_embeds is not None:
1062
+ input_shape = inputs_embeds.size()[:-1]
1063
+ else:
1064
+ raise ValueError(
1065
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds")
1066
+
1067
+ if inputs_embeds is None:
1068
+ inputs_embeds = self.embed_tokens(input_ids)
1069
+
1070
+ batch_size, seq_length = input_shape
1071
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1072
+ # required mask seq length can be calculated via length of past
1073
+ mask_seq_length = past_key_values_length + seq_length
1074
+
1075
+ # embed positions
1076
+ if self._use_flash_attention_2:
1077
+ # 2d mask is passed through the layers
1078
+ causal_attention_mask = attention_mask if (
1079
+ attention_mask is not None and 0 in attention_mask) else None
1080
+ attention_mask = (
1081
+ torch.ones(batch_size, mask_seq_length,
1082
+ device=inputs_embeds.device)
1083
+ if attention_mask is None
1084
+ else attention_mask
1085
+ )
1086
+ else:
1087
+ # 4d mask is passed through the layers
1088
+ if attention_mask is None:
1089
+ attention_mask = torch.ones(
1090
+ batch_size, mask_seq_length, device=inputs_embeds.device)
1091
+ elif attention_mask.shape[1] != mask_seq_length:
1092
+ raise ValueError(
1093
+ f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
1094
+ f"{mask_seq_length} (sum of the lengths of current and past inputs)"
1095
+ )
1096
+ causal_attention_mask = _prepare_4d_causal_attention_mask(
1097
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1098
+ )
1099
+
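For reference, a small sketch of the 4-D additive mask produced by the non-flash branch above; it assumes a recent `transformers` release where the private helper `_prepare_4d_causal_attention_mask` can be imported from `modeling_attn_mask_utils`, so treat it as illustrative only:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_len, hidden = 2, 5, 8
attention_mask = torch.ones(batch_size, seq_len)
inputs_embeds = torch.zeros(batch_size, seq_len, hidden)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_len), inputs_embeds, past_key_values_length=0
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5]): 0 where attention is allowed, large negative elsewhere
```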
1100
+ pos_embeds = self.embed_positions(
1101
+ attention_mask, past_key_values_length)
1102
+
1103
+ if self.project_in is not None:
1104
+ inputs_embeds = self.project_in(inputs_embeds)
1105
+
1106
+ hidden_states = inputs_embeds + pos_embeds
1107
+
1108
+ if self.gradient_checkpointing and self.training:
1109
+ if use_cache:
1110
+ logger.warning_once(
1111
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1112
+ )
1113
+ use_cache = False
1114
+
1115
+ # decoder layers
1116
+ all_hidden_states = () if output_hidden_states else None
1117
+ all_self_attns = () if output_attentions else None
1118
+ next_decoder_cache = () if use_cache else None
1119
+
1120
+ # check if head_mask has a correct number of layers specified if desired
1121
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
1122
+ if attn_mask is not None:
1123
+ if attn_mask.size()[0] != (len(self.layers)):
1124
+ raise ValueError(
1125
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1126
+ f" {head_mask.size()[0]}."
1127
+ )
1128
+
1129
+ for idx, decoder_layer in enumerate(self.layers):
1130
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1131
+ if output_hidden_states:
1132
+ all_hidden_states += (hidden_states,)
1133
+
1134
+ if self.training:
1135
+ dropout_probability = torch.rand([])
1136
+ if dropout_probability < self.layerdrop:
1137
+ continue
1138
+
1139
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1140
+
1141
+ if self.gradient_checkpointing and self.training:
1142
+ layer_outputs = self._gradient_checkpointing_func(
1143
+ decoder_layer.__call__,
1144
+ hidden_states,
1145
+ causal_attention_mask,
1146
+ head_mask[idx] if head_mask is not None else None,
1147
+ None,
1148
+ output_attentions,
1149
+ use_cache,
1150
+ )
1151
+ else:
1152
+ layer_outputs = decoder_layer(
1153
+ hidden_states,
1154
+ attention_mask=causal_attention_mask,
1155
+ layer_head_mask=(
1156
+ head_mask[idx] if head_mask is not None else None),
1157
+ past_key_value=past_key_value,
1158
+ output_attentions=output_attentions,
1159
+ use_cache=use_cache,
1160
+ )
1161
+
1162
+ hidden_states = layer_outputs[0]
1163
+
1164
+ if use_cache:
1165
+ next_decoder_cache += (
1166
+ layer_outputs[2 if output_attentions else 1],)
1167
+
1168
+ if output_attentions:
1169
+ all_self_attns += (layer_outputs[1],)
1170
+
1171
+ if self.final_layer_norm is not None:
1172
+ hidden_states = self.final_layer_norm(hidden_states)
1173
+
1174
+ if self.project_out is not None:
1175
+ hidden_states = self.project_out(hidden_states)
1176
+
1177
+ # add hidden states from the last decoder layer
1178
+ if output_hidden_states:
1179
+ all_hidden_states += (hidden_states,)
1180
+
1181
+ next_cache = next_decoder_cache if use_cache else None
1182
+ if not return_dict:
1183
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1184
+ return BaseModelOutputWithPast(
1185
+ last_hidden_state=hidden_states,
1186
+ past_key_values=next_cache,
1187
+ hidden_states=all_hidden_states,
1188
+ attentions=all_self_attns,
1189
+ )
1190
+
1191
+
1192
+ @add_start_docstrings(
1193
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
1194
+ OPT_START_DOCSTRING,
1195
+ )
1196
+ class OPTModel(OPTPreTrainedModel):
1197
+ def __init__(self, config: OPTConfig):
1198
+ super().__init__(config)
1199
+ self.decoder = OPTDecoder(config)
1200
+ # Initialize weights and apply final processing
1201
+ self.post_init()
1202
+
1203
+ def get_input_embeddings(self):
1204
+ return self.decoder.embed_tokens
1205
+
1206
+ def set_input_embeddings(self, value):
1207
+ self.decoder.embed_tokens = value
1208
+
1209
+ def get_decoder(self):
1210
+ return self.decoder
1211
+
1212
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1213
+ @add_code_sample_docstrings(
1214
+ checkpoint=_CHECKPOINT_FOR_DOC,
1215
+ output_type=BaseModelOutputWithPast,
1216
+ config_class=_CONFIG_FOR_DOC,
1217
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1218
+ )
1219
+ def forward(
1220
+ self,
1221
+ input_ids: torch.LongTensor = None,
1222
+ attention_mask: Optional[torch.Tensor] = None,
1223
+ head_mask: Optional[torch.Tensor] = None,
1224
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1225
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1226
+ use_cache: Optional[bool] = None,
1227
+ output_attentions: Optional[bool] = None,
1228
+ output_hidden_states: Optional[bool] = None,
1229
+ return_dict: Optional[bool] = None,
1230
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1231
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1232
+ output_hidden_states = (
1233
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1234
+ )
1235
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1237
+
1238
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1239
+ decoder_outputs = self.decoder(
1240
+ input_ids=input_ids,
1241
+ attention_mask=attention_mask,
1242
+ head_mask=head_mask,
1243
+ past_key_values=past_key_values,
1244
+ inputs_embeds=inputs_embeds,
1245
+ use_cache=use_cache,
1246
+ output_attentions=output_attentions,
1247
+ output_hidden_states=output_hidden_states,
1248
+ return_dict=return_dict,
1249
+ )
1250
+
1251
+ if not return_dict:
1252
+ return decoder_outputs
1253
+
1254
+ return BaseModelOutputWithPast(
1255
+ last_hidden_state=decoder_outputs.last_hidden_state,
1256
+ past_key_values=decoder_outputs.past_key_values,
1257
+ hidden_states=decoder_outputs.hidden_states,
1258
+ attentions=decoder_outputs.attentions,
1259
+ )
1260
+
1261
+
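A brief usage sketch of the bare model (hidden states only, no head on top); it assumes the public `facebook/opt-350m` checkpoint and the stock library classes that these definitions mirror:

```python
import torch
from transformers import AutoTokenizer, OPTModel

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = OPTModel.from_pretrained("facebook/opt-350m")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# (batch_size, sequence_length, word_embed_proj_dim); 512 for opt-350m
print(outputs.last_hidden_state.shape)
```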
1262
+ class OPTForCausalLM(OPTPreTrainedModel):
1263
+ _tied_weights_keys = ["lm_head.weight"]
1264
+
1265
+ def __init__(self, config):
1266
+ super().__init__(config)
1267
+ self.model = OPTModel(config)
1268
+
1269
+ # the lm_head weight is automatically tied to the embed tokens weight
1270
+ self.lm_head = nn.Linear(
1271
+ config.word_embed_proj_dim, config.vocab_size, bias=False)
1272
+
1273
+ # Initialize weights and apply final processing
1274
+ self.post_init()
1275
+
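The comment above relies on weight tying between `lm_head` and the input embeddings. A quick, hedged way to check the shared storage (assuming the stock `facebook/opt-350m` checkpoint with its default `tie_word_embeddings=True`):

```python
from transformers import OPTForCausalLM

model = OPTForCausalLM.from_pretrained("facebook/opt-350m")

# Should print True when config.tie_word_embeddings is enabled (the OPT default),
# because lm_head.weight then shares storage with the input embedding matrix.
print(
    model.lm_head.weight.data_ptr()
    == model.get_input_embeddings().weight.data_ptr()
)
```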
1276
+ def get_input_embeddings(self):
1277
+ return self.model.decoder.embed_tokens
1278
+
1279
+ def set_input_embeddings(self, value):
1280
+ self.model.decoder.embed_tokens = value
1281
+
1282
+ def get_output_embeddings(self):
1283
+ return self.lm_head
1284
+
1285
+ def set_output_embeddings(self, new_embeddings):
1286
+ self.lm_head = new_embeddings
1287
+
1288
+ def set_decoder(self, decoder):
1289
+ self.model.decoder = decoder
1290
+
1291
+ def get_decoder(self):
1292
+ return self.model.decoder
1293
+
1294
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1295
+ def forward(
1296
+ self,
1297
+ input_ids: torch.LongTensor = None,
1298
+ attention_mask: Optional[torch.Tensor] = None,
1299
+ head_mask: Optional[torch.Tensor] = None,
1300
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1301
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1302
+ labels: Optional[torch.LongTensor] = None,
1303
+ use_cache: Optional[bool] = None,
1304
+ output_attentions: Optional[bool] = None,
1305
+ output_hidden_states: Optional[bool] = None,
1306
+ return_dict: Optional[bool] = None,
1307
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1308
+ r"""
1309
+ Args:
1310
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1311
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1312
+ provide it.
1313
+
1314
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1315
+ [`PreTrainedTokenizer.__call__`] for details.
1316
+
1317
+ [What are input IDs?](../glossary#input-ids)
1318
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1319
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1320
+
1321
+ - 1 for tokens that are **not masked**,
1322
+ - 0 for tokens that are **masked**.
1323
+
1324
+ [What are attention masks?](../glossary#attention-mask)
1325
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
1326
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1327
+
1328
+ - 1 indicates the head is **not masked**,
1329
+ - 0 indicates the head is **masked**.
1330
+
1331
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1332
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1333
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1334
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1335
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
1336
+
1337
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1338
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1339
+
1340
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1341
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1342
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1343
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1344
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1345
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1346
+ than the model's internal embedding lookup matrix.
1347
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1348
+ Labels for computing the causal language modeling loss. Indices should either be in `[0, ...,
1349
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1350
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1351
+ use_cache (`bool`, *optional*):
1352
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1353
+ (see `past_key_values`).
1354
+ output_attentions (`bool`, *optional*):
1355
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1356
+ returned tensors for more detail.
1357
+ output_hidden_states (`bool`, *optional*):
1358
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1359
+ for more detail.
1360
+ return_dict (`bool`, *optional*):
1361
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1362
+
1363
+ Returns:
1364
+
1365
+ Example:
1366
+
1367
+ ```python
1368
+ >>> from transformers import AutoTokenizer, OPTForCausalLM
1369
+
1370
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
1371
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
1372
+
1373
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1374
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1375
+
1376
+ >>> # Generate
1377
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1378
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1379
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
1380
+ ```"""
1381
+
1382
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1383
+ output_hidden_states = (
1384
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1385
+ )
1386
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1387
+
1388
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1389
+ outputs = self.model.decoder(
1390
+ input_ids=input_ids,
1391
+ attention_mask=attention_mask,
1392
+ head_mask=head_mask,
1393
+ past_key_values=past_key_values,
1394
+ inputs_embeds=inputs_embeds,
1395
+ use_cache=use_cache,
1396
+ output_attentions=output_attentions,
1397
+ output_hidden_states=output_hidden_states,
1398
+ return_dict=return_dict,
1399
+ )
1400
+
1401
+ logits = self.lm_head(outputs[0]).contiguous()
1402
+
1403
+ loss = None
1404
+ if labels is not None:
1405
+ # move labels to correct device to enable model parallelism
1406
+ labels = labels.to(logits.device)
1407
+ # Shift so that tokens < n predict n
1408
+ shift_logits = logits[..., :-1, :].contiguous()
1409
+ shift_labels = labels[..., 1:].contiguous()
1410
+ # Flatten the tokens
1411
+ loss_fct = CrossEntropyLoss()
1412
+ loss = loss_fct(
1413
+ shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
1414
+
1415
+ if not return_dict:
1416
+ output = (logits,) + outputs[1:]
1417
+ return (loss,) + output if loss is not None else output
1418
+
1419
+ return CausalLMOutputWithPast(
1420
+ loss=loss,
1421
+ logits=logits,
1422
+ past_key_values=outputs.past_key_values,
1423
+ hidden_states=outputs.hidden_states,
1424
+ attentions=outputs.attentions,
1425
+ )
1426
+
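The shift-by-one loss in the forward pass above is easier to see on a toy tensor; the numbers below are made up purely for illustration:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)    # (batch, seq_len, vocab)
labels = torch.tensor([[2, 5, 7, -100]])  # -100 marks positions to ignore

shift_logits = logits[..., :-1, :].contiguous()  # drop the final time step
shift_labels = labels[..., 1:].contiguous()      # drop the first token
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)  # averaged over the two positions with valid (non -100) targets
```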
1427
+ def prepare_inputs_for_generation(
1428
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1429
+ ):
1430
+ if past_key_values is not None:
1431
+ past_length = past_key_values[0][0].shape[2]
1432
+
1433
+ # Some generation methods already pass only the last input ID
1434
+ if input_ids.shape[1] > past_length:
1435
+ remove_prefix_length = past_length
1436
+ else:
1437
+ # Default to old behavior: keep only final ID
1438
+ remove_prefix_length = input_ids.shape[1] - 1
1439
+
1440
+ input_ids = input_ids[:, remove_prefix_length:]
1441
+
1442
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1443
+ if inputs_embeds is not None and past_key_values is None:
1444
+ model_inputs = {"inputs_embeds": inputs_embeds}
1445
+ else:
1446
+ model_inputs = {"input_ids": input_ids}
1447
+
1448
+ model_inputs.update(
1449
+ {
1450
+ "past_key_values": past_key_values,
1451
+ "use_cache": kwargs.get("use_cache"),
1452
+ "attention_mask": attention_mask,
1453
+ }
1454
+ )
1455
+ return model_inputs
1456
+
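A toy illustration of the trimming rule in `prepare_inputs_for_generation` above, with invented token ids and an assumed cache length:

```python
import torch

input_ids = torch.tensor([[5, 8, 13, 21]])
past_length = 3  # tokens already covered by past_key_values

if input_ids.shape[1] > past_length:
    input_ids = input_ids[:, past_length:]
else:
    input_ids = input_ids[:, -1:]  # fall back to the final ID only

print(input_ids)  # tensor([[21]]): only the unseen suffix is re-fed to the model
```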
1457
+ @staticmethod
1458
+ def _reorder_cache(past_key_values, beam_idx):
1459
+ reordered_past = ()
1460
+ for layer_past in past_key_values:
1461
+ reordered_past += (
1462
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device))
1463
+ for past_state in layer_past),
1464
+ )
1465
+ return reordered_past
1466
+
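And a toy illustration of `_reorder_cache`: each cached key/value tensor is gathered along the batch-of-beams dimension with `beam_idx` (all shapes below are invented for the example):

```python
import torch

num_layers, num_beams, heads, seq_len, head_dim = 2, 4, 2, 3, 8
past_key_values = tuple(
    (
        torch.randn(num_beams, heads, seq_len, head_dim),  # cached keys
        torch.randn(num_beams, heads, seq_len, head_dim),  # cached values
    )
    for _ in range(num_layers)
)
beam_idx = torch.tensor([1, 1, 3, 0])  # beams chosen at this decoding step

reordered = tuple(
    tuple(t.index_select(0, beam_idx) for t in layer) for layer in past_key_values
)
print(reordered[0][0].shape)  # same shape as before, rows permuted to follow the beams
```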
1467
+
1468
+ @add_start_docstrings(
1469
+ """
1470
+ The OPT Model transformer with a sequence classification head on top (linear layer).
1471
+
1472
+ [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1473
+ (e.g. GPT-2) do.
1474
+
1475
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1476
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1477
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1478
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1479
+ each row of the batch).
1480
+ """,
1481
+ OPT_START_DOCSTRING,
1482
+ )
1483
+ class OPTForSequenceClassification(OPTPreTrainedModel):
1484
+ def __init__(self, config: OPTConfig):
1485
+ super().__init__(config)
1486
+ self.num_labels = config.num_labels
1487
+ self.model = OPTModel(config)
1488
+ self.score = nn.Linear(config.word_embed_proj_dim,
1489
+ self.num_labels, bias=False)
1490
+
1491
+ # Initialize weights and apply final processing
1492
+ self.post_init()
1493
+
1494
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1495
+ @add_code_sample_docstrings(
1496
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1497
+ output_type=SequenceClassifierOutputWithPast,
1498
+ config_class=_CONFIG_FOR_DOC,
1499
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1500
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1501
+ )
1502
+ def forward(
1503
+ self,
1504
+ input_ids: Optional[torch.LongTensor] = None,
1505
+ attention_mask: Optional[torch.FloatTensor] = None,
1506
+ head_mask: Optional[torch.FloatTensor] = None,
1507
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1508
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1509
+ labels: Optional[torch.LongTensor] = None,
1510
+ use_cache: Optional[bool] = None,
1511
+ output_attentions: Optional[bool] = None,
1512
+ output_hidden_states: Optional[bool] = None,
1513
+ return_dict: Optional[bool] = None,
1514
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1515
+ r"""
1516
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1517
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1518
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1519
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1520
+ """
1521
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1522
+
1523
+ transformer_outputs = self.model(
1524
+ input_ids,
1525
+ past_key_values=past_key_values,
1526
+ attention_mask=attention_mask,
1527
+ head_mask=head_mask,
1528
+ inputs_embeds=inputs_embeds,
1529
+ use_cache=use_cache,
1530
+ output_attentions=output_attentions,
1531
+ output_hidden_states=output_hidden_states,
1532
+ return_dict=return_dict,
1533
+ )
1534
+ hidden_states = transformer_outputs[0]
1535
+ logits = self.score(hidden_states)
1536
+
1537
+ if input_ids is not None:
1538
+ batch_size, sequence_length = input_ids.shape[:2]
1539
+ else:
1540
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1541
+
1542
+ if self.config.pad_token_id is None:
1543
+ sequence_lengths = -1
1544
+ else:
1545
+ if input_ids is not None:
1546
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1547
+ sequence_lengths = torch.eq(
1548
+ input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1549
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1550
+ sequence_lengths = sequence_lengths.to(logits.device)
1551
+ else:
1552
+ sequence_lengths = -1
1553
+ logger.warning(
1554
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1555
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1556
+ )
1557
+
1558
+ pooled_logits = logits[torch.arange(
1559
+ batch_size, device=logits.device), sequence_lengths]
1560
+
1561
+ loss = None
1562
+ if labels is not None:
1563
+ if self.config.problem_type is None:
1564
+ if self.num_labels == 1:
1565
+ self.config.problem_type = "regression"
1566
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1567
+ self.config.problem_type = "single_label_classification"
1568
+ else:
1569
+ self.config.problem_type = "multi_label_classification"
1570
+
1571
+ if self.config.problem_type == "regression":
1572
+ loss_fct = MSELoss()
1573
+ if self.num_labels == 1:
1574
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1575
+ else:
1576
+ loss = loss_fct(pooled_logits, labels)
1577
+ elif self.config.problem_type == "single_label_classification":
1578
+ loss_fct = CrossEntropyLoss()
1579
+ loss = loss_fct(
1580
+ pooled_logits.view(-1, self.num_labels), labels.view(-1))
1581
+ elif self.config.problem_type == "multi_label_classification":
1582
+ loss_fct = BCEWithLogitsLoss()
1583
+ loss = loss_fct(pooled_logits, labels)
1584
+ if not return_dict:
1585
+ output = (pooled_logits,) + transformer_outputs[1:]
1586
+ return ((loss,) + output) if loss is not None else output
1587
+
1588
+ return SequenceClassifierOutputWithPast(
1589
+ loss=loss,
1590
+ logits=pooled_logits,
1591
+ past_key_values=transformer_outputs.past_key_values,
1592
+ hidden_states=transformer_outputs.hidden_states,
1593
+ attentions=transformer_outputs.attentions,
1594
+ )
1595
+
1596
+ def get_input_embeddings(self):
1597
+ return self.model.decoder.embed_tokens
1598
+
1599
+ def set_input_embeddings(self, value):
1600
+ self.model.decoder.embed_tokens = value
1601
+
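A small sketch of the last-non-padding-token selection this classification head performs (see the class docstring and the `sequence_lengths` computation above); the batch is invented and `pad_token_id=1` is OPT's usual padding id:

```python
import torch

pad_token_id = 1  # OPT's default padding token id
input_ids = torch.tensor(
    [[5, 7, 9, 1, 1],
     [4, 6, 1, 1, 1]]  # right-padded batch
)

# argmax finds the first pad token; subtracting 1 gives the last real token,
# and the modulo keeps the index valid when a row contains no padding at all.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 1])
```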
1602
+
1603
+ @add_start_docstrings(
1604
+ """
1605
+ The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
1606
+ (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1607
+ """,
1608
+ OPT_START_DOCSTRING,
1609
+ )
1610
+ class OPTForQuestionAnswering(OPTPreTrainedModel):
1611
+ def __init__(self, config: OPTConfig):
1612
+ super().__init__(config)
1613
+ self.model = OPTModel(config)
1614
+ self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)
1615
+
1616
+ # Initialize weights and apply final processing
1617
+ self.post_init()
1618
+
1619
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1620
+ @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
1621
+ def forward(
1622
+ self,
1623
+ input_ids: Optional[torch.LongTensor] = None,
1624
+ attention_mask: Optional[torch.FloatTensor] = None,
1625
+ head_mask: Optional[torch.FloatTensor] = None,
1626
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1627
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1628
+ start_positions: Optional[torch.LongTensor] = None,
1629
+ end_positions: Optional[torch.LongTensor] = None,
1630
+ use_cache: Optional[bool] = None,
1631
+ output_attentions: Optional[bool] = None,
1632
+ output_hidden_states: Optional[bool] = None,
1633
+ return_dict: Optional[bool] = None,
1634
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1635
+ r"""
1636
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1637
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1638
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1639
+ are not taken into account for computing the loss.
1640
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1641
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1642
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1643
+ are not taken into account for computing the loss.
1644
+
1645
+ Returns:
1646
+
1647
+ Example:
1648
+
1649
+ ```python
1650
+ >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
1651
+ >>> import torch
1652
+
1653
+ >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
1654
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
1655
+
1656
+ >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
1657
+ >>> # so the head will be randomly initialized, hence the predictions will be random
1658
+ >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
1659
+
1660
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
1661
+
1662
+ >>> inputs = tokenizer(question, text, return_tensors="pt")
1663
+ >>> with torch.no_grad():
1664
+ ... outputs = model(**inputs)
1665
+
1666
+ >>> answer_start_index = outputs.start_logits.argmax()
1667
+ >>> answer_end_index = outputs.end_logits.argmax()
1668
+
1669
+ >>> answer_offset = len(tokenizer(question)[0])
1670
+
1671
+ >>> predict_answer_tokens = inputs.input_ids[
1672
+ ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
1673
+ ... ]
1674
+ >>> predicted = tokenizer.decode(predict_answer_tokens)
1675
+ >>> predicted
1676
+ ' a nice puppet'
1677
+ ```"""
1678
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1679
+
1680
+ transformer_outputs = self.model(
1681
+ input_ids,
1682
+ past_key_values=past_key_values,
1683
+ attention_mask=attention_mask,
1684
+ head_mask=head_mask,
1685
+ inputs_embeds=inputs_embeds,
1686
+ use_cache=use_cache,
1687
+ output_attentions=output_attentions,
1688
+ output_hidden_states=output_hidden_states,
1689
+ return_dict=return_dict,
1690
+ )
1691
+ hidden_states = transformer_outputs[0]
1692
+
1693
+ logits = self.qa_outputs(hidden_states)
1694
+ start_logits, end_logits = logits.split(1, dim=-1)
1695
+ start_logits = start_logits.squeeze(-1).contiguous()
1696
+ end_logits = end_logits.squeeze(-1).contiguous()
1697
+
1698
+ total_loss = None
1699
+ if start_positions is not None and end_positions is not None:
1700
+ # If we are on multi-GPU, split adds an extra dimension; squeeze it away
1701
+ if len(start_positions.size()) > 1:
1702
+ start_positions = start_positions.squeeze(-1)
1703
+ if len(end_positions.size()) > 1:
1704
+ end_positions = end_positions.squeeze(-1)
1705
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1706
+ ignored_index = start_logits.size(1)
1707
+ start_positions = start_positions.clamp(
1708
+ 0, ignored_index).to(logits.device)
1709
+ end_positions = end_positions.clamp(
1710
+ 0, ignored_index).to(logits.device)
1711
+
1712
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1713
+ start_loss = loss_fct(start_logits, start_positions)
1714
+ end_loss = loss_fct(end_logits, end_positions)
1715
+ total_loss = (start_loss + end_loss) / 2
1716
+
1717
+ if not return_dict:
1718
+ output = (start_logits, end_logits) + transformer_outputs[2:]
1719
+ return ((total_loss,) + output) if total_loss is not None else output
1720
+
1721
+ return QuestionAnsweringModelOutput(
1722
+ loss=total_loss,
1723
+ start_logits=start_logits,
1724
+ end_logits=end_logits,
1725
+ hidden_states=transformer_outputs.hidden_states,
1726
+ attentions=transformer_outputs.attentions,
1727
+ )
1728
+
1729
+ def get_input_embeddings(self):
1730
+ return self.model.decoder.embed_tokens
1731
+
1732
+ def set_input_embeddings(self, value):
1733
+ self.model.decoder.embed_tokens = value
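Finally, the span-loss clamping in `OPTForQuestionAnswering.forward` maps out-of-range start/end positions to `ignored_index` so they do not contribute to the loss; a toy sketch with invented logits:

```python
import torch
from torch.nn import CrossEntropyLoss

seq_len = 6
start_logits = torch.randn(2, seq_len)
start_positions = torch.tensor([2, 50])   # second target lies outside the sequence

ignored_index = start_logits.size(1)      # == 6, one past the last valid position
start_positions = start_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
loss = loss_fct(start_logits, start_positions)
print(start_positions, loss)  # tensor([2, 6]); only the first example contributes
```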