Question Answering
Transformers
Safetensors
English
doge
text-generation
custom_code
JingzeShi committed on
Commit 067141a (verified) · 1 Parent(s): 990fb24

Upload DogeForCausalLM

Files changed (5)
  1. config.json +47 -37
  2. configuration_doge.py +83 -46
  3. generation_config.json +7 -7
  4. model.safetensors +2 -2
  5. modeling_doge.py +382 -256
config.json CHANGED
@@ -1,37 +1,47 @@
-{
-  "_name_or_path": "./results/Doge-60M-Instruct",
-  "architectures": [
-    "DogeForCausalLM"
-  ],
-  "attention_dropout": 0.0,
-  "auto_map": {
-    "AutoConfig": "configuration_doge.DogeConfig",
-    "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
-  },
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "expert_retrieval_size": 256,
-  "hidden_act": "silu",
-  "hidden_bias": false,
-  "hidden_dropout": 0.0,
-  "hidden_size": 512,
-  "initializer_range": 0.02,
-  "intermediate_size": 2048,
-  "is_moe": false,
-  "max_position_embeddings": 2048,
-  "model_type": "doge",
-  "num_attention_heads": 4,
-  "num_cdmmoe_experts": 4096,
-  "num_cdmmoe_experts_per_head": 8,
-  "num_cdmmoe_heads": 4,
-  "num_hidden_layers": 8,
-  "pad_token_id": 0,
-  "rms_norm_eps": 1e-06,
-  "rope_scaling": null,
-  "rope_theta": 10000.0,
-  "tie_word_embeddings": false,
-  "torch_dtype": "float32",
-  "transformers_version": "4.46.1",
-  "use_cache": true,
-  "vocab_size": 32768
-}
+{
+  "_name_or_path": "./results/Doge-60M-Instruct-DPO",
+  "architectures": [
+    "DogeForCausalLM"
+  ],
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_doge.DogeConfig",
+    "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
+  },
+  "bos_token_id": 0,
+  "dynamic_mask_ratio": 0.0,
+  "eos_token_id": 1,
+  "expert_retrieval_size": 256,
+  "hidden_act": "silu",
+  "hidden_bias": false,
+  "hidden_dropout": 0.0,
+  "hidden_size": 512,
+  "initializer_range": 0.02,
+  "intermediate_size": 1024,
+  "is_moe": false,
+  "max_position_embeddings": 2048,
+  "model_type": "doge",
+  "num_attention_heads": 4,
+  "num_cdmmoe_experts": 2048,
+  "num_cdmmoe_experts_per_head": 8,
+  "num_cdmmoe_heads": 4,
+  "num_cdmoe_experts": 16348,
+  "num_cdmoe_experts_per_head": 8,
+  "num_cdmoe_heads": 4,
+  "num_channels": 3,
+  "num_hidden_layers": 16,
+  "num_key_value_heads": 2,
+  "pad_token_id": 2,
+  "patch_size": 16,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": {
+    "factor": 4.0,
+    "original_max_position_embeddings": 2048,
+    "rope_type": "dynamic"
+  },
+  "rope_theta": 10000.0,
+  "torch_dtype": "float32",
+  "transformers_version": "4.49.0.dev0",
+  "use_cache": true,
+  "vocab_size": 32768
+}
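
The updated config keeps the `auto_map` entries pointing at the custom `configuration_doge.DogeConfig` and `modeling_doge.DogeForCausalLM` classes, so loading this checkpoint goes through `trust_remote_code`. A minimal loading sketch; the repo id below is a placeholder, not taken from this commit:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "path/or/repo-id-of-this-checkpoint"  # placeholder, not part of the commit

# auto_map in config.json routes these calls to the custom Doge classes
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

print(config.num_attention_heads, config.num_key_value_heads)  # 4, 2 per the new config
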
configuration_doge.py CHANGED
@@ -25,20 +25,23 @@ from transformers.modeling_rope_utils import rope_config_validation
25
  class DogeConfig(PretrainedConfig):
26
  r"""
27
  This is the configuration class to store the configuration of a [`DogeModel`]. It is used to instantiate an Doge
28
- model according to the specified arguments, defining the model architecture like [LoserCheems/doge-tiny-test](https://huggingface.co/LoserCheems/doge-tiny-test)
29
 
30
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
  documentation from [`PretrainedConfig`] for more information.
32
 
33
  Args:
34
  vocab_size (`int`, *optional*, defaults to 32768):
35
- Vocabulary size of the Doge model. Defines the number of different tokens that can be represented by the
36
- `inputs_ids` passed when calling [`DogeModel`]
 
 
 
37
  hidden_size (`int`, *optional*, defaults to 1024):
38
  Dimension of the hidden representations.
39
- intermediate_size (`int`, *optional*, defaults to 4096):
40
- Dimension of the CDMoE representations.
41
- num_hidden_layers (`int`, *optional*, defaults to 16):
42
  Number of hidden layers in the Transformer decoder.
43
  hidden_bias (`bool`, *optional*, defaults to `False`):
44
  Whether to use bias in the hidden layers.
@@ -51,24 +54,21 @@ class DogeConfig(PretrainedConfig):
51
  rope_theta (`float`, *optional*, defaults to 10000.0):
52
  The base period of the RoPE embeddings.
53
  rope_scaling (`Dict`, *optional*):
54
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
55
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
56
- accordingly.
57
  Expected contents:
58
  `rope_type` (`str`):
59
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
60
- 'llama3'], with 'default' being the original RoPE implementation.
61
  `factor` (`float`, *optional*):
62
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
63
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
64
- original maximum pre-trained length.
65
  `original_max_position_embeddings` (`int`, *optional*):
66
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
67
- pretraining.
68
  `attention_factor` (`float`, *optional*):
69
  Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
70
- computation. If unspecified, it defaults to value recommended by the implementation, using the
71
- `factor` field to infer the suggested value.
72
  `beta_fast` (`float`, *optional*):
73
  Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
74
  ramp function. If unspecified, it defaults to 32.
@@ -76,13 +76,11 @@ class DogeConfig(PretrainedConfig):
76
  Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
77
  ramp function. If unspecified, it defaults to 1.
78
  `short_factor` (`List[float]`, *optional*):
79
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
80
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
81
- size divided by the number of attention heads divided by 2
82
  `long_factor` (`List[float]`, *optional*):
83
- Only used with 'longrope'. The scaling factor to be applied to long contexts (<
84
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
85
- size divided by the number of attention heads divided by 2
86
  `low_freq_factor` (`float`, *optional*):
87
  Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
88
  `high_freq_factor` (`float`, *optional*):
@@ -100,56 +98,86 @@ class DogeConfig(PretrainedConfig):
100
  Beginning of stream token id.
101
  eos_token_id (`int`, *optional*, defaults to 2):
102
  End of stream token id.
103
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
104
  Whether to tie weight embeddings
105
  num_attention_heads (`int`, *optional*, defaults to 8):
106
  Number of attention heads for each attention layer in the Transformer decoder.
 
 
 
 
 
 
 
107
  attention_dropout (`float`, *optional*, defaults to 0.0):
108
  The dropout ratio for the attention probabilities.
 
 
109
  is_moe (`bool`, *optional*, defaults to `False`):
110
  Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
111
- num_cdmmoe_experts (`int`, *optional*, defaults to 4096):
112
- Number of Private Experts for the Cross Domain Mixture of Experts.
113
- num_cdmmoe_heads (`int`, *optional*, defaults to 4):
114
  Number of heads of Private Experts for the Cross Domain Mixture of Experts.
115
- num_cdmmoe_experts_per_head (`int`, *optional*, defaults to 8):
116
  Number of Private Experts per head for the Cross Domain Mixture of Experts.
117
- expert_retrieval_size (`int`, *optional*, defaults to 256):
118
  Dimension of the Expert retrieval states for the Cross Domain Mixture of Experts.
119
  """
120
 
121
  model_type = "doge"
122
  keys_to_ignore_at_inference = ["past_key_values"]
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  def __init__(
125
  self,
126
  vocab_size=32768,
 
 
127
  hidden_size=1024,
128
- intermediate_size=4096,
129
- num_hidden_layers=16,
130
  hidden_bias=False,
131
  hidden_dropout=0.0,
132
  hidden_act="silu",
133
  max_position_embeddings=2048,
134
  rope_theta=10000.0,
135
- rope_scaling=None,
 
 
 
 
136
  initializer_range=0.02,
137
  rms_norm_eps=1e-06,
138
  use_cache=True,
139
- pad_token_id=0,
140
- bos_token_id=1,
141
- eos_token_id=2,
142
- tie_word_embeddings=False,
143
  num_attention_heads=8,
 
144
  attention_dropout=0.0,
 
145
  is_moe=False,
146
- num_cdmmoe_experts=4096,
147
- num_cdmmoe_heads=4,
148
- num_cdmmoe_experts_per_head=8,
149
- expert_retrieval_size=256,
150
  **kwargs,
151
  ):
152
  self.vocab_size = vocab_size
 
 
153
  self.hidden_size = hidden_size
154
  self.intermediate_size = intermediate_size
155
  self.num_hidden_layers = num_hidden_layers
@@ -162,16 +190,18 @@ class DogeConfig(PretrainedConfig):
162
  self.initializer_range = initializer_range
163
  self.rms_norm_eps = rms_norm_eps
164
  self.use_cache = use_cache
165
- self.pad_token_id = pad_token_id
166
  self.bos_token_id = bos_token_id
167
  self.eos_token_id = eos_token_id
 
168
  self.tie_word_embeddings = tie_word_embeddings
169
  self.num_attention_heads = num_attention_heads
 
170
  self.attention_dropout = attention_dropout
 
171
  self.is_moe = is_moe
172
- self.num_cdmmoe_experts = num_cdmmoe_experts
173
- self.num_cdmmoe_heads = num_cdmmoe_heads
174
- self.num_cdmmoe_experts_per_head = num_cdmmoe_experts_per_head
175
  self.expert_retrieval_size = expert_retrieval_size
176
 
177
  # Validate the correctness of rotary position embeddings parameters
@@ -180,10 +210,17 @@ class DogeConfig(PretrainedConfig):
180
  self.rope_scaling["rope_type"] = self.rope_scaling["type"]
181
  rope_config_validation(self)
182
 
 
 
 
 
183
  super().__init__(
184
- pad_token_id=pad_token_id,
185
  bos_token_id=bos_token_id,
186
  eos_token_id=eos_token_id,
 
187
  tie_word_embeddings=tie_word_embeddings,
188
  **kwargs,
189
  )
 
 
 
 
25
  class DogeConfig(PretrainedConfig):
26
  r"""
27
  This is the configuration class to store the configuration of a [`DogeModel`]. It is used to instantiate an Doge
28
+ model according to the specified arguments, defining the model architecture like [JingzeShi/Doge-20M](https://huggingface.co/JingzeShi/Doge-20M).
29
 
30
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
  documentation from [`PretrainedConfig`] for more information.
32
 
33
  Args:
34
  vocab_size (`int`, *optional*, defaults to 32768):
35
+ Vocabulary size of the Doge model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DogeModel`]
36
+ num_channels (`int`, *optional*, defaults to 3):
37
+ Number of channels in the input image.
38
+ patch_size (`int`, *optional*, defaults to 16):
39
+ Patch size of Vision Transformer Embeddings.
40
  hidden_size (`int`, *optional*, defaults to 1024):
41
  Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 2048):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
  Number of hidden layers in the Transformer decoder.
46
  hidden_bias (`bool`, *optional*, defaults to `False`):
47
  Whether to use bias in the hidden layers.
 
54
  rope_theta (`float`, *optional*, defaults to 10000.0):
55
  The base period of the RoPE embeddings.
56
  rope_scaling (`Dict`, *optional*):
57
+ Dictionary containing the scaling configuration for the RoPE embeddings.
58
+ NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly.
 
59
  Expected contents:
60
  `rope_type` (`str`):
61
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation.
 
62
  `factor` (`float`, *optional*):
63
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings.
64
+ In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length.
 
65
  `original_max_position_embeddings` (`int`, *optional*):
66
+ Used with 'dynamic', 'longrope' and 'llama3'.
67
+ The original max position embeddings used during pretraining.
68
  `attention_factor` (`float`, *optional*):
69
  Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
70
+ computation.
71
+ If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value.
72
  `beta_fast` (`float`, *optional*):
73
  Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
74
  ramp function. If unspecified, it defaults to 32.
 
76
  Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
77
  ramp function. If unspecified, it defaults to 1.
78
  `short_factor` (`List[float]`, *optional*):
79
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<`original_max_position_embeddings`).
80
+ Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2
 
81
  `long_factor` (`List[float]`, *optional*):
82
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<`original_max_position_embeddings`).
83
+ Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2
 
84
  `low_freq_factor` (`float`, *optional*):
85
  Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
86
  `high_freq_factor` (`float`, *optional*):
 
98
  Beginning of stream token id.
99
  eos_token_id (`int`, *optional*, defaults to 2):
100
  End of stream token id.
101
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
102
  Whether to tie weight embeddings
103
  num_attention_heads (`int`, *optional*, defaults to 8):
104
  Number of attention heads for each attention layer in the Transformer decoder.
105
+ num_key_value_heads (`int`, *optional*, defaults to `None`):
106
+ This is the number of key_value heads that should be used to implement Grouped Query Attention.
107
+ If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
108
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.
109
+ When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group.
110
+ For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf).
111
+ If it is not specified, will default to `num_attention_heads`.
112
  attention_dropout (`float`, *optional*, defaults to 0.0):
113
  The dropout ratio for the attention probabilities.
114
+ dynamic_mask_ratio (`float`, *optional*, defaults to 0.0, range [0, 1]):
115
+ The ratio to control the proportion of the dynamic mask filled with the minimum value.
116
  is_moe (`bool`, *optional*, defaults to `False`):
117
  Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
118
+ num_cdmoe_experts (`int`, *optional*, defaults to 16348):
119
+ Number of Private Experts for the Cross Domain Mixture of Experts. calculation formula: :math:`\text{num_cdmoe_experts} = (32 \times \text{num_cdmoe_heads})^2`
120
+ num_cdmoe_heads (`int`, *optional*, defaults to 4):
121
  Number of heads of Private Experts for the Cross Domain Mixture of Experts.
122
+ num_cdmoe_experts_per_head (`int`, *optional*, defaults to 8):
123
  Number of Private Experts per head for the Cross Domain Mixture of Experts.
124
+ expert_retrieval_size (`int`, *optional*, defaults to 64):
125
  Dimension of the Expert retrieval states for the Cross Domain Mixture of Experts.
126
  """
127
 
128
  model_type = "doge"
129
  keys_to_ignore_at_inference = ["past_key_values"]
130
+ # Default tensor parallel plan for base model `DogeModel`
131
+ base_model_tp_plan = {
132
+ "layers.*.self_attn.q_proj": "colwise",
133
+ "layers.*.self_attn.k_proj": "colwise",
134
+ "layers.*.self_attn.v_proj": "colwise",
135
+ "layers.*.self_attn.dt_proj": "colwise",
136
+ "layers.*.self_attn.o_proj": "rowwise",
137
+ "layers.*.mlp.gate_proj": "colwise",
138
+ "layers.*.mlp.up_proj": "colwise",
139
+ "layers.*.mlp.down_proj": "rowwise",
140
+ }
141
 
142
  def __init__(
143
  self,
144
  vocab_size=32768,
145
+ num_channels=3,
146
+ patch_size=16,
147
  hidden_size=1024,
148
+ intermediate_size=2048,
149
+ num_hidden_layers=32,
150
  hidden_bias=False,
151
  hidden_dropout=0.0,
152
  hidden_act="silu",
153
  max_position_embeddings=2048,
154
  rope_theta=10000.0,
155
+ rope_scaling={
156
+ "rope_type": "dynamic",
157
+ "factor": 4.0,
158
+ "original_max_position_embeddings": 2048,
159
+ },
160
  initializer_range=0.02,
161
  rms_norm_eps=1e-06,
162
  use_cache=True,
163
+ bos_token_id=0,
164
+ eos_token_id=1,
165
+ pad_token_id=2,
166
+ tie_word_embeddings=True,
167
  num_attention_heads=8,
168
+ num_key_value_heads=None,
169
  attention_dropout=0.0,
170
+ dynamic_mask_ratio=0.0,
171
  is_moe=False,
172
+ num_cdmoe_experts=16348,
173
+ num_cdmoe_heads=4,
174
+ num_cdmoe_experts_per_head=8,
175
+ expert_retrieval_size=64,
176
  **kwargs,
177
  ):
178
  self.vocab_size = vocab_size
179
+ self.num_channels = num_channels
180
+ self.patch_size = patch_size
181
  self.hidden_size = hidden_size
182
  self.intermediate_size = intermediate_size
183
  self.num_hidden_layers = num_hidden_layers
 
190
  self.initializer_range = initializer_range
191
  self.rms_norm_eps = rms_norm_eps
192
  self.use_cache = use_cache
 
193
  self.bos_token_id = bos_token_id
194
  self.eos_token_id = eos_token_id
195
+ self.pad_token_id = pad_token_id
196
  self.tie_word_embeddings = tie_word_embeddings
197
  self.num_attention_heads = num_attention_heads
198
+ self.num_key_value_heads = num_key_value_heads
199
  self.attention_dropout = attention_dropout
200
+ self.dynamic_mask_ratio = dynamic_mask_ratio
201
  self.is_moe = is_moe
202
+ self.num_cdmoe_experts = num_cdmoe_experts
203
+ self.num_cdmoe_heads = num_cdmoe_heads
204
+ self.num_cdmoe_experts_per_head = num_cdmoe_experts_per_head
205
  self.expert_retrieval_size = expert_retrieval_size
206
 
207
  # Validate the correctness of rotary position embeddings parameters
 
210
  self.rope_scaling["rope_type"] = self.rope_scaling["type"]
211
  rope_config_validation(self)
212
 
213
+ # for backward compatibility
214
+ if num_key_value_heads is None:
215
+ self.num_key_value_heads = num_attention_heads
216
+
217
  super().__init__(
 
218
  bos_token_id=bos_token_id,
219
  eos_token_id=eos_token_id,
220
+ pad_token_id=pad_token_id,
221
  tie_word_embeddings=tie_word_embeddings,
222
  **kwargs,
223
  )
224
+
225
+
226
+ __all__ = ["DogeConfig"]
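
The new `__init__` signature above adds `num_key_value_heads`, `dynamic_mask_ratio`, remapped token-id defaults, and a dict default for `rope_scaling`. A minimal construction sketch, assuming `configuration_doge.py` from this commit is importable locally; the values mirror config.json above:

from configuration_doge import DogeConfig

config = DogeConfig(
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=16,
    num_attention_heads=4,
    num_key_value_heads=2,  # GQA: 4 query heads share 2 key/value heads
    dynamic_mask_ratio=0.0,
    rope_scaling={
        "rope_type": "dynamic",
        "factor": 4.0,
        "original_max_position_embeddings": 2048,
    },
)
print(config.num_key_value_heads)  # 2; left as None it falls back to num_attention_heads
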
generation_config.json CHANGED
@@ -1,7 +1,7 @@
-{
-  "_from_model_config": true,
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "pad_token_id": 0,
-  "transformers_version": "4.46.1"
-}
+{
+  "_from_model_config": true,
+  "bos_token_id": 0,
+  "eos_token_id": 1,
+  "pad_token_id": 2,
+  "transformers_version": "4.49.0.dev0"
+}
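
The special-token ids are remapped in this commit (bos 1→0, eos 2→1, pad 0→2), and generation_config.json stays in sync with config.json. A small sketch of the equivalent in-code object:

from transformers import GenerationConfig

gen_config = GenerationConfig(bos_token_id=0, eos_token_id=1, pad_token_id=2)
print(gen_config.eos_token_id)  # 1 — generation now stops on the new eos id
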
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bb865063ca9536a49054948ea12f7d90519ee363ee98c5fff7bc2de6f82e0a86
-size 268580408
+oid sha256:2d30a2a446050f4e9c26bb833e260e5479937577b280c16d1e39f8ce4e66aba1
+size 218325576
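
model.safetensors is stored as a Git LFS pointer, so only the sha256 and size change here. An optional local sanity check, sketched with a placeholder path:

import hashlib
import os

path = "model.safetensors"  # placeholder: local path to the downloaded weight file

print(os.path.getsize(path))  # expected: 218325576 after this commit

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)
print(sha256.hexdigest())  # expected: 2d30a2a446050f4e9c26bb833e260e5479937577b280c16d1e39f8ce4e66aba1
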
modeling_doge.py CHANGED
@@ -19,7 +19,7 @@
19
  """PyTorch Doge model."""
20
 
21
  import math
22
- from typing import List, Optional, Tuple, Union
23
 
24
  import torch
25
  import torch.nn.functional as F
@@ -36,9 +36,12 @@ from transformers.modeling_outputs import (
36
  )
37
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
38
  from transformers.modeling_utils import PreTrainedModel
 
39
  from transformers.utils import (
 
40
  add_start_docstrings,
41
  add_start_docstrings_to_model_forward,
 
42
  logging,
43
  replace_return_docstrings,
44
  )
@@ -49,6 +52,9 @@ try:
49
  except ImportError:
50
  einx_add = None
51
 
 
 
 
52
 
53
  logger = logging.get_logger(__name__)
54
 
@@ -79,7 +85,7 @@ class Residual(nn.Module):
79
  def __init__(self, hidden_size):
80
  super().__init__()
81
  self.weight = nn.Parameter(torch.ones(hidden_size))
82
-
83
  def forward(self, residual_states, hidden_states):
84
  return self.weight * residual_states + hidden_states
85
 
@@ -92,10 +98,10 @@ class RotaryEmbedding(nn.Module):
92
  super().__init__()
93
  self.rope_kwargs = {}
94
 
95
- if config.rope_scaling is None:
96
- self.rope_type = "default"
97
  else:
98
- self.rope_type = config.rope_scaling
99
  self.max_seq_len_cached = config.max_position_embeddings
100
  self.original_max_seq_len = config.max_position_embeddings
101
  self.base = config.rope_theta
@@ -133,6 +139,7 @@ class RotaryEmbedding(nn.Module):
133
  # core RoPE block
134
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
135
  position_ids_expanded = position_ids[:, None, :].float()
 
136
  device_type = x.device.type
137
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
138
  with torch.autocast(device_type=device_type, enabled=False):
@@ -141,6 +148,7 @@ class RotaryEmbedding(nn.Module):
141
  cos = emb.cos()
142
  sin = emb.sin()
143
 
 
144
  cos = cos * self.attention_scaling
145
  sin = sin * self.attention_scaling
146
 
@@ -168,11 +176,10 @@ def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
168
  Deprecated and unused.
169
  unsqueeze_dim (`int`, *optional*, defaults to 1):
170
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
171
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
172
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
173
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
174
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
175
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
176
  Returns:
177
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
178
  """
@@ -183,82 +190,83 @@ def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
183
  return q_embed, k_embed
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  class DogeDynamicMaskAttention(nn.Module):
187
  """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
188
 
189
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
190
  super().__init__()
191
-
192
  self.config = config
193
  self.layer_idx = layer_idx
194
- if layer_idx is None:
195
- logger.warning_once(
196
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
197
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
198
- "when creating this class."
199
- )
200
-
201
- self.hidden_dim = config.hidden_size
202
- self.num_attention_heads = config.num_attention_heads
203
  self.attention_dropout = config.attention_dropout
204
- self.attention_head_dim = self.hidden_dim // self.num_attention_heads
 
 
 
 
 
 
205
 
206
  # Q K V O projections
207
  self.q_proj = nn.Linear(
208
- self.hidden_dim,
209
- self.num_attention_heads * self.attention_head_dim,
210
- bias=config.hidden_bias,
211
  )
212
  self.k_proj = nn.Linear(
213
- self.hidden_dim,
214
- self.num_attention_heads * self.attention_head_dim,
215
- bias=config.hidden_bias,
 
 
 
 
 
216
  )
217
  # dynamic mask for the QK^T attention score matrix
218
  self.A = nn.Parameter(
219
- torch.ones(self.num_attention_heads)
220
  )
221
  self.dt_proj = nn.Linear(
222
- self.hidden_dim,
223
- self.num_attention_heads,
224
- bias=config.hidden_bias,
225
- )
226
- self.v_proj = nn.Linear(
227
- self.hidden_dim,
228
- self.num_attention_heads * self.attention_head_dim,
229
- bias=config.hidden_bias,
230
  )
231
  self.o_proj = nn.Linear(
232
- self.hidden_dim,
233
- self.hidden_dim,
234
- bias=config.hidden_bias,
235
  )
236
 
237
  def forward(
238
  self,
239
  hidden_states: torch.Tensor,
 
240
  attention_mask: Optional[torch.Tensor] = None,
241
- position_ids: Optional[torch.LongTensor] = None,
242
  past_key_value: Optional[Cache] = None,
243
  cache_position: Optional[torch.LongTensor] = None,
244
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
245
  **kwargs,
246
  ) -> Tuple[torch.Tensor, Optional[Cache]]:
247
- bsz, q_len, _ = hidden_states.shape
248
-
249
- query_states = self.q_proj(hidden_states)
250
- key_states = self.k_proj(hidden_states)
251
- value_states = self.v_proj(hidden_states)
252
 
253
- query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
254
- 1, 2
255
- )
256
- key_states = key_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
257
- 1, 2
258
- )
259
- value_states = value_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
260
- 1, 2
261
- )
262
 
263
  cos, sin = position_embeddings
264
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -268,90 +276,153 @@ class DogeDynamicMaskAttention(nn.Module):
268
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
269
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
270
 
271
- # compute attention scores matrix
272
- attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.attention_head_dim)
273
-
274
- # add mask to attention scores
275
- if attention_mask is not None:
276
- dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
277
- dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
278
- dynamic_mask = dynamic_mask < 1.0
279
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
280
- attn_weights = attn_weights + causal_mask
281
-
282
- # upcast attention scores to fp32
283
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
284
- attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
285
 
286
- # apply attention scores to value states
287
- attn_output = torch.matmul(attn_weights, value_states)
 
 
 
 
 
 
 
 
 
 
 
288
 
289
- attn_output = attn_output.transpose(1, 2).contiguous()
290
- attn_output = attn_output.reshape(bsz, q_len, -1)
291
  attn_output = self.o_proj(attn_output)
 
292
 
293
- return attn_output, past_key_value
294
-
295
-
296
- class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
297
-
298
- def forward(
299
  self,
300
  hidden_states: torch.Tensor,
 
 
301
  attention_mask: Optional[torch.Tensor] = None,
302
- position_ids: Optional[torch.LongTensor] = None,
303
- past_key_value: Optional[Cache] = None,
304
- cache_position: Optional[torch.LongTensor] = None,
305
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
306
- **kwargs,
307
- ) -> Tuple[torch.Tensor, Optional[Cache]]:
308
- bsz, q_len, _ = hidden_states.shape
309
-
310
- query_states = self.q_proj(hidden_states)
311
- key_states = self.k_proj(hidden_states)
312
- value_states = self.v_proj(hidden_states)
313
-
314
- query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
315
- key_states = key_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
316
- value_states = value_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
317
 
318
- cos, sin = position_embeddings
319
- query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
- if past_key_value is not None:
322
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
323
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
324
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  if attention_mask is not None:
327
- dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
328
- dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
329
- dynamic_mask = dynamic_mask < 1.0
330
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
331
 
332
- query_states = query_states.contiguous()
333
- key_states = key_states.contiguous()
334
- value_states = value_states.contiguous()
 
 
335
 
 
 
336
  attn_output = F.scaled_dot_product_attention(
337
- query_states,
338
- key_states,
339
- value_states,
340
  attn_mask=causal_mask,
341
- dropout_p=self.attention_dropout,
 
 
342
  )
343
-
344
  attn_output = attn_output.transpose(1, 2).contiguous()
345
- attn_output = attn_output.view(bsz, q_len, -1)
346
- attn_output = self.o_proj(attn_output)
347
-
348
- return attn_output, past_key_value
349
-
 
 
 
 
 
 
 
 
 
 
350
 
351
- DOGE_ATTENTION_CLASSES = {
352
- "eager": DogeDynamicMaskAttention,
353
- "sdpa": DogeSdpaDynamicMaskAttn,
354
- }
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
 
357
  class DogeMLP(nn.Module):
@@ -362,21 +433,9 @@ class DogeMLP(nn.Module):
362
  self.intermediate_dim = config.intermediate_size
363
  self.act_fn = ACT2FN[config.hidden_act]
364
 
365
- self.gate_proj = nn.Linear(
366
- self.hidden_dim,
367
- self.intermediate_dim,
368
- bias=config.hidden_bias,
369
- )
370
- self.up_proj = nn.Linear(
371
- self.hidden_dim,
372
- self.intermediate_dim,
373
- bias=config.hidden_bias,
374
- )
375
- self.down_proj = nn.Linear(
376
- self.intermediate_dim,
377
- self.hidden_dim,
378
- bias=config.hidden_bias,
379
- )
380
 
381
  def forward(
382
  self,
@@ -396,36 +455,18 @@ class DogeCDMoE(DogeMLP):
396
  self.act_fn = ACT2FN[config.hidden_act]
397
 
398
  self.expert_retrieval_dim = config.expert_retrieval_size
399
- self.num_cdmmoe_experts = config.num_cdmmoe_experts
400
- self.num_cdmmoe_heads = config.num_cdmmoe_heads
401
- self.num_cdmmoe_experts_per_head = config.num_cdmmoe_experts_per_head
402
- self.num_keys = int(math.sqrt(self.num_cdmmoe_experts))
403
 
404
  # queries and keys for retrieval experts
405
- self.queries = nn.Linear(
406
- self.hidden_dim,
407
- self.num_cdmmoe_heads * self.expert_retrieval_dim,
408
- bias=False,
409
- )
410
- self.keys = nn.Parameter(
411
- torch.zeros(
412
- self.num_cdmmoe_heads,
413
- self.num_keys,
414
- 2,
415
- self.expert_retrieval_dim // 2,
416
- )
417
- )
418
 
419
  # experts
420
- self.down_embed = nn.Embedding(
421
- self.num_cdmmoe_experts,
422
- self.hidden_dim,
423
- )
424
- self.up_embed = nn.Embedding(
425
- self.num_cdmmoe_experts,
426
- self.hidden_dim,
427
- )
428
-
429
 
430
  def forward(
431
  self,
@@ -436,11 +477,11 @@ class DogeCDMoE(DogeMLP):
436
 
437
  # get similarity with queries and keys
438
  queries = self.queries(hidden_states)
439
- queries = queries.view(bsz, seq_len, 2, self.num_cdmmoe_heads, -1).permute(2, 0, 1, 3, 4)
440
  sim = torch.einsum("p b t h n, h k p n -> p b t h k", queries, self.keys)
441
 
442
  # get experts with the highest similarity
443
- (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.num_cdmmoe_experts_per_head, dim=-1)
444
  if einx_add is not None:
445
  all_scores = einx_add("... i, ... j -> ... (i j)", scores_x, scores_y)
446
  all_indices = einx_add("... i, ... j -> ... (i j)", indices_x * self.num_keys, indices_y)
@@ -449,7 +490,7 @@ class DogeCDMoE(DogeMLP):
449
  all_scores = all_scores.view(*scores_x.shape[:-1], -1)
450
  all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
451
  all_indices = all_indices.view(*indices_x.shape[:-1], -1)
452
- scores, pk_indices = all_scores.topk(self.num_cdmmoe_experts_per_head, dim=-1)
453
  indices = all_indices.gather(-1, pk_indices)
454
  down_embed = self.down_embed(indices)
455
  up_embed = self.up_embed(indices)
@@ -468,13 +509,13 @@ class DogeDecoderLayer(nn.Module):
468
  super().__init__()
469
  self.hidden_dropout = config.hidden_dropout
470
 
471
- self.pre_sequence_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472
- self.attn = DOGE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
473
- self.post_sequence_residual = Residual(config.hidden_size)
474
 
475
- self.pre_state_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
476
  self.feed_forward = DogeMLP(config) if config.is_moe == False else DogeCDMoE(config)
477
- self.post_state_residual = Residual(config.hidden_size)
478
 
479
  def forward(
480
  self,
@@ -485,36 +526,14 @@ class DogeDecoderLayer(nn.Module):
485
  output_attentions: Optional[bool] = False,
486
  use_cache: Optional[bool] = False,
487
  cache_position: Optional[torch.LongTensor] = None,
488
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
489
  **kwargs,
490
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
491
- """
492
- Args:
493
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
494
- attention_mask (`torch.FloatTensor`, *optional*):
495
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
496
- query_sequence_length, key_sequence_length)` if default attention is used.
497
- output_attentions (`bool`, *optional*):
498
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
499
- returned tensors for more detail.
500
- use_cache (`bool`, *optional*):
501
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
502
- (see `past_key_values`).
503
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
504
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
505
- Indices depicting the position of the input sequence tokens in the sequence
506
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
507
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
508
- with `head_dim` being the embedding dimension of each attention head.
509
- kwargs (`dict`, *optional*):
510
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
511
- into the model
512
- """
513
 
514
  # sequence transformation
515
  residual = hidden_states
516
- hidden_states = self.pre_sequence_layernorm(hidden_states)
517
- hidden_states, present_key_value = self.attn(
518
  hidden_states=hidden_states,
519
  attention_mask=attention_mask,
520
  position_ids=position_ids,
@@ -525,27 +544,41 @@ class DogeDecoderLayer(nn.Module):
525
  )
526
  self_attn_weights = None
527
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
528
- hidden_states = self.post_sequence_residual(residual, hidden_states)
529
 
530
  # state transformation
531
  residual = hidden_states
532
- hidden_states = self.pre_state_layernorm(hidden_states)
533
  hidden_states = self.feed_forward(hidden_states)
534
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
535
- hidden_states = self.post_state_residual(residual, hidden_states)
536
 
537
  outputs = (hidden_states,)
538
-
539
  if output_attentions:
540
  outputs += (self_attn_weights,)
541
 
542
- if use_cache:
543
- outputs += (present_key_value,)
544
-
545
  return outputs
546
 
547
 
548
- @add_start_docstrings("The bare Doge Model outputting raw hidden-states without any specific head on top.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  class DogePreTrainedModel(PreTrainedModel):
550
  config_class = DogeConfig
551
  base_model_prefix = "model"
@@ -553,6 +586,7 @@ class DogePreTrainedModel(PreTrainedModel):
553
  _no_split_modules = ["DogeDecoderLayer"]
554
  _skip_keys_device_placement = ["past_key_values"]
555
  _supports_sdpa = True
 
556
  _supports_cache_class = True
557
  _supports_quantized_cache = True
558
  _supports_static_cache = True
@@ -644,8 +678,18 @@ DOGE_INPUTS_DOCSTRING = r"""
644
  """
645
 
646
 
647
- @add_start_docstrings("The bare Doge Model outputting raw hidden-states without any specific head on top.")
 
 
 
648
  class DogeModel(DogePreTrainedModel):
 
 
 
 
 
 
 
649
  def __init__(self, config: DogeConfig):
650
  super().__init__(config)
651
  self.config = config
@@ -682,6 +726,7 @@ class DogeModel(DogePreTrainedModel):
682
  output_hidden_states: Optional[bool] = None,
683
  return_dict: Optional[bool] = None,
684
  cache_position: Optional[torch.LongTensor] = None,
 
685
  ) -> Union[Tuple, BaseModelOutputWithPast]:
686
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
687
  output_hidden_states = (
@@ -702,33 +747,22 @@ class DogeModel(DogePreTrainedModel):
702
  if inputs_embeds is None:
703
  inputs_embeds = self.word_embed(input_ids)
704
 
705
- # kept for BC (non `Cache` `past_key_values` inputs)
706
- return_legacy_cache = False
707
- if use_cache and not isinstance(past_key_values, Cache):
708
- return_legacy_cache = True
709
- if past_key_values is None:
710
- past_key_values = DynamicCache()
711
- else:
712
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
713
- logger.warning_once(
714
- "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
715
- "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
716
- "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
717
- )
718
 
719
  if cache_position is None:
720
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
721
  cache_position = torch.arange(
722
- past_seen_tokens,
723
- past_seen_tokens + inputs_embeds.shape[1],
724
- device=inputs_embeds.device,
725
  )
 
726
  if position_ids is None:
727
  position_ids = cache_position.unsqueeze(0)
728
 
729
  causal_mask = self._update_causal_mask(
730
  attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
731
  )
 
732
  hidden_states = inputs_embeds
733
 
734
  # create position embeddings to be shared across the decoder layers
@@ -737,9 +771,8 @@ class DogeModel(DogePreTrainedModel):
737
  # decoder layers
738
  all_hidden_states = () if output_hidden_states else None
739
  all_self_attns = () if output_attentions else None
740
- next_decoder_cache = None
741
 
742
- for decoder_layer in self.layers:
743
  if output_hidden_states:
744
  all_hidden_states += (hidden_states,)
745
 
@@ -765,13 +798,11 @@ class DogeModel(DogePreTrainedModel):
765
  use_cache=use_cache,
766
  cache_position=cache_position,
767
  position_embeddings=position_embeddings,
 
768
  )
769
 
770
  hidden_states = layer_outputs[0]
771
 
772
- if use_cache:
773
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
774
-
775
  if output_attentions:
776
  all_self_attns += (layer_outputs[1],)
777
 
@@ -781,27 +812,21 @@ class DogeModel(DogePreTrainedModel):
781
  if output_hidden_states:
782
  all_hidden_states += (hidden_states,)
783
 
784
- next_cache = next_decoder_cache if use_cache else None
785
- if return_legacy_cache:
786
- next_cache = next_cache.to_legacy_cache()
787
-
788
- if not return_dict:
789
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
790
-
791
- return BaseModelOutputWithPast(
792
  last_hidden_state=hidden_states,
793
- past_key_values=next_cache,
794
  hidden_states=all_hidden_states,
795
  attentions=all_self_attns,
796
  )
 
797
 
798
  def _update_causal_mask(
799
  self,
800
- attention_mask: torch.Tensor = None,
801
- input_tensor: torch.Tensor = None,
802
- cache_position: torch.Tensor = None,
803
- past_key_values: Cache = None,
804
- output_attentions: bool = False,
805
  ):
806
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
807
  using_static_cache = isinstance(past_key_values, StaticCache)
@@ -888,8 +913,12 @@ class DogeModel(DogePreTrainedModel):
888
  return causal_mask
889
 
890
 
 
 
 
891
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
892
  _tied_weights_keys = ["lm_head.weight"]
 
893
 
894
  def __init__(self, config: DogeConfig):
895
  super().__init__(config)
@@ -912,13 +941,13 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
912
 
913
  def set_output_embeddings(self, new_embeddings):
914
  self.lm_head = new_embeddings
 
 
 
915
 
916
  def set_decoder(self, decoder):
917
  self.model = decoder
918
 
919
- def get_decoder(self):
920
- return self.model
921
-
922
  @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
923
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
924
  def forward(
@@ -926,7 +955,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
926
  input_ids: torch.LongTensor = None,
927
  attention_mask: Optional[torch.Tensor] = None,
928
  position_ids: Optional[torch.LongTensor] = None,
929
- past_key_values: Optional[torch.Tensor] = None,
930
  inputs_embeds: Optional[torch.FloatTensor] = None,
931
  labels: Optional[torch.LongTensor] = None,
932
  use_cache: Optional[bool] = None,
@@ -935,7 +964,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
935
  return_dict: Optional[bool] = None,
936
  cache_position: Optional[torch.LongTensor] = None,
937
  num_logits_to_keep: int = 0,
938
- **loss_kwargs,
939
  ) -> Union[Tuple, CausalLMOutputWithPast]:
940
  r"""
941
  Args:
@@ -950,7 +979,23 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
950
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
951
 
952
  Returns:
953
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
954
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
955
  output_hidden_states = (
956
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -969,6 +1014,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
969
  output_hidden_states=output_hidden_states,
970
  return_dict=return_dict,
971
  cache_position=cache_position,
 
972
  )
973
 
974
  hidden_states = outputs[0]
@@ -978,7 +1024,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
978
 
979
  loss = None
980
  if labels is not None:
981
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **loss_kwargs)
982
 
983
  if not return_dict:
984
  output = (logits,) + outputs[1:]
@@ -993,18 +1039,98 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
993
  )
994
 
995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
996
  @add_start_docstrings(
997
  """
998
  The Doge Model transformer with a sequence classification head on top (linear layer).
999
 
1000
- [`DogeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1001
- (e.g. GPT-2) do.
1002
 
1003
- Since it does classification on the last token, it requires to know the position of the last token. If a
1004
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1005
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1006
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1007
- each row of the batch).
1008
  """
1009
  )
1010
  class DogeForSequenceClassification(DogePreTrainedModel):
@@ -1041,9 +1167,9 @@ class DogeForSequenceClassification(DogePreTrainedModel):
1041
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1042
  r"""
1043
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1044
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1045
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1046
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1047
  """
1048
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1049
 
 
19
  """PyTorch Doge model."""
20
 
21
  import math
22
+ from typing import Callable, List, Optional, Tuple, Union
23
 
24
  import torch
25
  import torch.nn.functional as F
 
36
  )
37
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
38
  from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.processing_utils import Unpack
40
  from transformers.utils import (
41
+ LossKwargs,
42
  add_start_docstrings,
43
  add_start_docstrings_to_model_forward,
44
+ is_torch_greater_or_equal,
45
  logging,
46
  replace_return_docstrings,
47
  )
 
52
  except ImportError:
53
  einx_add = None
54
 
55
+ if is_torch_greater_or_equal("2.5"):
56
+ from torch.nn.attention.flex_attention import flex_attention
57
+
58
 
59
  logger = logging.get_logger(__name__)
60
 
 
85
  def __init__(self, hidden_size):
86
  super().__init__()
87
  self.weight = nn.Parameter(torch.ones(hidden_size))
88
+
89
  def forward(self, residual_states, hidden_states):
90
  return self.weight * residual_states + hidden_states
91
 
 
98
  super().__init__()
99
  self.rope_kwargs = {}
100
 
101
+ if config.rope_scaling is not None:
102
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
103
  else:
104
+ self.rope_type = "default"
105
  self.max_seq_len_cached = config.max_position_embeddings
106
  self.original_max_seq_len = config.max_position_embeddings
107
  self.base = config.rope_theta
 
139
  # core RoPE block
140
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
141
  position_ids_expanded = position_ids[:, None, :].float()
142
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
143
  device_type = x.device.type
144
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
145
  with torch.autocast(device_type=device_type, enabled=False):
 
148
  cos = emb.cos()
149
  sin = emb.sin()
150
 
151
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
152
  cos = cos * self.attention_scaling
153
  sin = sin * self.attention_scaling
154
 
 
176
  Deprecated and unused.
177
  unsqueeze_dim (`int`, *optional*, defaults to 1):
178
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
179
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k.
180
+ For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
181
+ Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k.
182
+ Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
 
183
  Returns:
184
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
185
  """
 
190
  return q_embed, k_embed
191
 
192
 
193
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
194
+ """
195
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
196
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
197
+ """
198
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
199
+ if n_rep == 1:
200
+ return hidden_states
201
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
202
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
203
+
204
+
205
  class DogeDynamicMaskAttention(nn.Module):
206
  """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
207
 
208
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
209
  super().__init__()
 
210
  self.config = config
211
  self.layer_idx = layer_idx
212
+ self.head_dim = config.hidden_size // config.num_attention_heads
213
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
214
+ self.scaling = self.head_dim ** -0.5
 
 
 
 
 
 
215
  self.attention_dropout = config.attention_dropout
216
+ self.dynamic_mask_ratio = config.dynamic_mask_ratio
217
+
218
+ self.ALL_ATTENTION_FUNCTIONS = {
219
+ "eager": self.eager_attention_forward,
220
+ "sdpa": self.sdpa_attention_forward,
221
+ "flex_attention": self.flex_attention_forward,
222
+ }
223
 
224
  # Q K V O projections
225
  self.q_proj = nn.Linear(
226
+ config.hidden_size,
227
+ config.num_attention_heads * self.head_dim,
228
+ bias=config.hidden_bias
229
  )
230
  self.k_proj = nn.Linear(
231
+ config.hidden_size,
232
+ config.num_key_value_heads * self.head_dim,
233
+ bias=config.hidden_bias
234
+ )
235
+ self.v_proj = nn.Linear(
236
+ config.hidden_size,
237
+ config.num_key_value_heads * self.head_dim,
238
+ bias=config.hidden_bias
239
  )
240
  # dynamic mask for the QK^T attention score matrix
241
  self.A = nn.Parameter(
242
+ torch.ones(config.num_attention_heads)
243
  )
244
  self.dt_proj = nn.Linear(
245
+ config.num_key_value_heads * self.head_dim,
246
+ config.num_attention_heads,
247
+ bias=config.hidden_bias
 
 
 
 
 
248
  )
249
  self.o_proj = nn.Linear(
250
+ config.num_attention_heads * self.head_dim,
251
+ config.hidden_size,
252
+ bias=config.hidden_bias
253
  )
254
 
255
  def forward(
256
  self,
257
  hidden_states: torch.Tensor,
258
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
259
  attention_mask: Optional[torch.Tensor] = None,
 
260
  past_key_value: Optional[Cache] = None,
261
  cache_position: Optional[torch.LongTensor] = None,
 
262
  **kwargs,
263
  ) -> Tuple[torch.Tensor, Optional[Cache]]:
264
+ input_shape = hidden_states.shape[:-1]
265
+ hidden_shape = (*input_shape, -1, self.head_dim)
 
 
 
266
 
267
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
268
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
269
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
 
 
 
 
 
270
 
271
  cos, sin = position_embeddings
272
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
 
276
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
277
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
278
 
279
+ # calculate dynamic mask from value_states
280
+ dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1))
281
+ dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
282
+ attn_mask = self.prepare_dynamic_mask(
283
+ hidden_states=hidden_states,
284
+ dynamic_mask=dynamic_mask,
285
+ dynamic_mask_ratio=self.dynamic_mask_ratio,
286
+ attention_mask=attention_mask,
287
+ )
 
 
 
 
 
288
 
289
+ attention_interface: Callable = self.eager_attention_forward
290
+ if self.config._attn_implementation != "eager":
291
+ attention_interface = self.ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
292
+
293
+ attn_output = attention_interface(
294
+ query_states,
295
+ key_states,
296
+ value_states,
297
+ attention_mask=attn_mask,
298
+ dropout=0.0 if not self.training else self.attention_dropout,
299
+ scaling=self.scaling,
300
+ **kwargs,
301
+ )
302
 
303
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 
304
  attn_output = self.o_proj(attn_output)
305
+ return attn_output
306
 
307
+ def prepare_dynamic_mask(
 
 
 
 
 
308
  self,
309
  hidden_states: torch.Tensor,
310
+ dynamic_mask: torch.Tensor,
311
+ dynamic_mask_ratio: float = 0.0,
312
  attention_mask: Optional[torch.Tensor] = None,
313
+ ):
314
+ """
315
+ Combine `dynamic_mask` with `attention_mask` to generate the final `attn_mask`.
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
+ Args:
318
+ hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
319
+ dynamic_mask (`torch.Tensor`): dynamic mask of shape `(batch_size, num_heads, key_sequence_length)`.
320
+ dynamic_mask_ratio (`float`, *optional*): Ratio from 0.0 to 1.0 used to control the proportion of the dynamic mask filled with the minimum value.
321
+ attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
322
+ """
323
+ min_type = torch.finfo(hidden_states.dtype).min
324
+ attn_mask = dynamic_mask[:, :, None, :]
325
+ if 0.0 < dynamic_mask_ratio < 1.0:
326
+ num_dynamic_mask = int(attn_mask.shape[-1] * dynamic_mask_ratio)
327
+ if num_dynamic_mask > 0:
328
+ rate_value = torch.kthvalue(attn_mask, num_dynamic_mask, dim=-1, keepdim=True).values
329
+ attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)
330
+ if attention_mask is not None:
331
+ attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
332
+ return attn_mask
333
+
334
+ def eager_attention_forward(
335
+ self,
336
+ query: torch.Tensor,
337
+ key: torch.Tensor,
338
+ value: torch.Tensor,
339
+ attention_mask: Optional[torch.Tensor],
340
+ scaling: float,
341
+ dropout: float = 0.0,
342
+ **kwargs,
343
+ ) -> torch.Tensor:
344
+ key_states = repeat_kv(key, self.num_key_value_groups)
345
+ value_states = repeat_kv(value, self.num_key_value_groups)
346
 
347
+ # compute attention scores matrix
348
+ attn_weights = torch.matmul(query, key_states.transpose(-1, -2)) * scaling
349
+ if attention_mask is not None:
350
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
351
+ attn_weights = attn_weights + causal_mask
352
+
353
+ # upcast attention scores to fp32
354
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
355
+ attn_weights = F.dropout(attn_weights, p=dropout, training=self.training)
356
 
357
+ # apply attention scores to value states
358
+ attn_output = torch.matmul(attn_weights, value_states)
359
+ attn_output = attn_output.transpose(1, 2).contiguous()
360
+ return attn_output
361
+
362
+ def sdpa_attention_forward(
363
+ self,
364
+ query: torch.Tensor,
365
+ key: torch.Tensor,
366
+ value: torch.Tensor,
367
+ attention_mask: Optional[torch.Tensor],
368
+ scaling: float,
369
+ dropout: float = 0.0,
370
+ **kwargs,
371
+ ) -> torch.Tensor:
372
+ causal_mask = attention_mask
373
  if attention_mask is not None:
374
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
 
 
 
375
 
376
+ # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
377
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
378
+ query = query.contiguous()
379
+ key = key.contiguous()
380
+ value = value.contiguous()
381
 
382
+ # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
383
+ torch.backends.cuda.enable_cudnn_sdp(False)
384
  attn_output = F.scaled_dot_product_attention(
385
+ query,
386
+ key,
387
+ value,
388
  attn_mask=causal_mask,
389
+ dropout_p=dropout,
390
+ scale=scaling,
391
+ enable_gqa=True,
392
  )
 
393
  attn_output = attn_output.transpose(1, 2).contiguous()
394
+ return attn_output
395
+
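A small sketch of what `enable_gqa=True` buys here (toy sizes; the flag requires PyTorch 2.5+): key/value tensors with fewer heads than the query are broadcast inside SDPA, so the explicit `repeat_kv` of the eager path is not needed.

```python
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq_len, head_dim = 1, 4, 2, 6, 16
query = torch.randn(batch, q_heads, seq_len, head_dim)
key = torch.randn(batch, kv_heads, seq_len, head_dim)    # only 2 key/value heads
value = torch.randn(batch, kv_heads, seq_len, head_dim)

# each group of 2 query heads shares one key/value head
out = F.scaled_dot_product_attention(query, key, value, scale=head_dim ** -0.5, enable_gqa=True)
print(out.shape)  # torch.Size([1, 4, 6, 16])
```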
396
+ def flex_attention_forward(
397
+ self,
398
+ query: torch.Tensor,
399
+ key: torch.Tensor,
400
+ value: torch.Tensor,
401
+ attention_mask: Optional[torch.Tensor],
402
+ scaling: float,
403
+ dropout: float = 0.0,
404
+ **kwargs,
405
+ ) -> torch.Tensor:
406
+ causal_mask = attention_mask
407
+ if attention_mask is not None:
408
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
409
 
410
+ # TODO: flex_attention: Captured buffers that require grad are not yet supported.
411
+ # NOTE: So we only use flex_attention in inference mode.
412
+ def mask_mod(score, batch, head, q_idx, kv_idx):
413
+ score = score + causal_mask[batch][head][q_idx][kv_idx]
414
+ return score
415
+
416
+ attn_output = flex_attention(
417
+ query,
418
+ key,
419
+ value,
420
+ score_mod=mask_mod,
421
+ scale=scaling,
422
+ enable_gqa=True,
423
+ )
424
+ attn_output = attn_output.transpose(1, 2).contiguous()
425
+ return attn_output
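A rough standalone sketch of the `score_mod` mechanism used above (toy shapes and a zero bias as a stand-in for the prepared mask; `flex_attention` requires PyTorch 2.5+ and, as noted, captured buffers that require grad are not yet supported, hence the `no_grad`).

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

batch, heads, seq_len, head_dim = 1, 2, 8, 16
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)
bias = torch.zeros(batch, heads, seq_len, seq_len)  # stand-in for the prepared attn_mask

def score_mod(score, b, h, q_idx, kv_idx):
    # add the captured additive mask/bias to each raw attention score
    return score + bias[b][h][q_idx][kv_idx]

with torch.no_grad():
    out = flex_attention(query, key, value, score_mod=score_mod, scale=head_dim ** -0.5)
print(out.shape)  # torch.Size([1, 2, 8, 16])
```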
426
 
427
 
428
  class DogeMLP(nn.Module):
 
433
  self.intermediate_dim = config.intermediate_size
434
  self.act_fn = ACT2FN[config.hidden_act]
435
 
436
+ self.gate_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=config.hidden_bias)
437
+ self.up_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=config.hidden_bias)
438
+ self.down_proj = nn.Linear(self.intermediate_dim, self.hidden_dim, bias=config.hidden_bias)
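The forward pass is not part of this hunk, but these three projections follow the common gated-MLP (SwiGLU-style) convention; the sketch below is an assumption based on that convention, not a copy of the actual `DogeMLP.forward`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, intermediate_dim = 512, 1024   # example sizes only
gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
up_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)

x = torch.randn(2, 8, hidden_dim)
# gated MLP: the activated gate branch scales the up branch, then project back down
y = down_proj(F.silu(gate_proj(x)) * up_proj(x))
print(y.shape)  # torch.Size([2, 8, 512])
```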
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
  def forward(
441
  self,
 
455
  self.act_fn = ACT2FN[config.hidden_act]
456
 
457
  self.expert_retrieval_dim = config.expert_retrieval_size
458
+ self.num_cdmoe_experts = config.num_cdmoe_experts
459
+ self.num_cdmoe_heads = config.num_cdmoe_heads
460
+ self.num_cdmoe_experts_per_head = config.num_cdmoe_experts_per_head
461
+ self.num_keys = int(math.sqrt(self.num_cdmoe_experts))
462
 
463
  # queries and keys for retrieval experts
464
+ self.queries = nn.Linear(self.hidden_dim, self.num_cdmoe_heads * self.expert_retrieval_dim, bias=False)
465
+ self.keys = nn.Parameter(torch.zeros(self.num_cdmoe_heads, self.num_keys, 2, self.expert_retrieval_dim // 2))
466
 
467
  # experts
468
+ self.down_embed = nn.Embedding(self.num_cdmoe_experts, self.hidden_dim)
469
+ self.up_embed = nn.Embedding(self.num_cdmoe_experts, self.hidden_dim)
470
 
471
  def forward(
472
  self,
 
477
 
478
  # get similarity with queries and keys
479
  queries = self.queries(hidden_states)
480
+ queries = queries.view(bsz, seq_len, 2, self.num_cdmoe_heads, -1).permute(2, 0, 1, 3, 4)
481
  sim = torch.einsum("p b t h n, h k p n -> p b t h k", queries, self.keys)
482
 
483
  # get experts with the highest similarity
484
+ (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.num_cdmoe_experts_per_head, dim=-1)
485
  if einx_add is not None:
486
  all_scores = einx_add("... i, ... j -> ... (i j)", scores_x, scores_y)
487
  all_indices = einx_add("... i, ... j -> ... (i j)", indices_x * self.num_keys, indices_y)
 
490
  all_scores = all_scores.view(*scores_x.shape[:-1], -1)
491
  all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
492
  all_indices = all_indices.view(*indices_x.shape[:-1], -1)
493
+ scores, pk_indices = all_scores.topk(self.num_cdmoe_experts_per_head, dim=-1)
494
  indices = all_indices.gather(-1, pk_indices)
495
  down_embed = self.down_embed(indices)
496
  up_embed = self.up_embed(indices)
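A minimal 1-D illustration of the product-key trick used above (toy sizes, single head and position): two small top-k searches over `num_keys` sub-keys are combined into candidate indices over `num_keys**2` experts, followed by a second top-k.

```python
import torch

num_keys, experts_per_head = 4, 2
scores_x, indices_x = torch.randn(num_keys).topk(experts_per_head)  # first query half vs. first sub-keys
scores_y, indices_y = torch.randn(num_keys).topk(experts_per_head)  # second query half vs. second sub-keys

# Cartesian combination: every (i, j) pair addresses expert i * num_keys + j
all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).view(-1)
all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).view(-1)

scores, pk_indices = all_scores.topk(experts_per_head)
expert_indices = all_indices.gather(-1, pk_indices)  # final expert ids in [0, num_keys**2)
print(expert_indices, scores)
```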
 
509
  super().__init__()
510
  self.hidden_dropout = config.hidden_dropout
511
 
512
+ self.pre_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
513
+ self.self_attn = DogeDynamicMaskAttention(config=config, layer_idx=layer_idx)
514
+ self.pre_residual = Residual(config.hidden_size)
515
 
516
+ self.post_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
517
  self.feed_forward = DogeMLP(config) if config.is_moe == False else DogeCDMoE(config)
518
+ self.post_residual = Residual(config.hidden_size)
519
 
520
  def forward(
521
  self,
 
526
  output_attentions: Optional[bool] = False,
527
  use_cache: Optional[bool] = False,
528
  cache_position: Optional[torch.LongTensor] = None,
529
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
530
  **kwargs,
531
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
532
 
533
  # sequence transformation
534
  residual = hidden_states
535
+ hidden_states = self.pre_layernorm(hidden_states)
536
+ hidden_states = self.self_attn(
537
  hidden_states=hidden_states,
538
  attention_mask=attention_mask,
539
  position_ids=position_ids,
 
544
  )
545
  self_attn_weights = None
546
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
547
+ hidden_states = self.pre_residual(residual, hidden_states)
548
 
549
  # state transformation
550
  residual = hidden_states
551
+ hidden_states = self.post_layernorm(hidden_states)
552
  hidden_states = self.feed_forward(hidden_states)
553
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
554
+ hidden_states = self.post_residual(residual, hidden_states)
555
 
556
  outputs = (hidden_states,)
 
557
  if output_attentions:
558
  outputs += (self_attn_weights,)
559
 
 
 
 
560
  return outputs
561
 
562
 
563
+ DOGE_START_DOCSTRING = r"""
564
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
565
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
566
+ etc.)
567
+
568
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
569
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
570
+ and behavior.
571
+
572
+ Parameters:
573
+ config ([`DogeConfig`]):
574
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
575
+ load the weights associated with the model, only the configuration. Check out the
576
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
577
+ """
578
+ @add_start_docstrings(
579
+ "The bare Doge Model outputting raw hidden-states without any specific head on top.",
580
+ DOGE_START_DOCSTRING,
581
+ )
582
  class DogePreTrainedModel(PreTrainedModel):
583
  config_class = DogeConfig
584
  base_model_prefix = "model"
 
586
  _no_split_modules = ["DogeDecoderLayer"]
587
  _skip_keys_device_placement = ["past_key_values"]
588
  _supports_sdpa = True
589
+ _supports_flex_attn = True
590
  _supports_cache_class = True
591
  _supports_quantized_cache = True
592
  _supports_static_cache = True
 
678
  """
679
 
680
 
681
+ @add_start_docstrings(
682
+ "The bare Doge Model outputting raw hidden-states without any specific head on top.",
683
+ DOGE_START_DOCSTRING,
684
+ )
685
  class DogeModel(DogePreTrainedModel):
686
+ """
687
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DogeDecoderLayer`].
688
+
689
+ Args:
690
+ config: DogeConfig
691
+ """
692
+
693
  def __init__(self, config: DogeConfig):
694
  super().__init__(config)
695
  self.config = config
 
726
  output_hidden_states: Optional[bool] = None,
727
  return_dict: Optional[bool] = None,
728
  cache_position: Optional[torch.LongTensor] = None,
729
+ **kwargs,
730
  ) -> Union[Tuple, BaseModelOutputWithPast]:
731
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
732
  output_hidden_states = (
 
747
  if inputs_embeds is None:
748
  inputs_embeds = self.word_embed(input_ids)
749
 
750
+ if use_cache and past_key_values is None:
751
+ past_key_values = DynamicCache()
752
 
753
  if cache_position is None:
754
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
755
  cache_position = torch.arange(
756
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
757
  )
758
+
759
  if position_ids is None:
760
  position_ids = cache_position.unsqueeze(0)
761
 
762
  causal_mask = self._update_causal_mask(
763
  attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
764
  )
765
+
766
  hidden_states = inputs_embeds
767
 
768
  # create position embeddings to be shared across the decoder layers
 
771
  # decoder layers
772
  all_hidden_states = () if output_hidden_states else None
773
  all_self_attns = () if output_attentions else None
 
774
 
775
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
776
  if output_hidden_states:
777
  all_hidden_states += (hidden_states,)
778
 
 
798
  use_cache=use_cache,
799
  cache_position=cache_position,
800
  position_embeddings=position_embeddings,
801
+ **kwargs,
802
  )
803
 
804
  hidden_states = layer_outputs[0]
805
 
 
 
 
806
  if output_attentions:
807
  all_self_attns += (layer_outputs[1],)
808
 
 
812
  if output_hidden_states:
813
  all_hidden_states += (hidden_states,)
814
 
815
+ output = BaseModelOutputWithPast(
816
  last_hidden_state=hidden_states,
817
+ past_key_values=past_key_values if use_cache else None,
818
  hidden_states=all_hidden_states,
819
  attentions=all_self_attns,
820
  )
821
+ return output if return_dict else output.to_tuple()
822
 
823
  def _update_causal_mask(
824
  self,
825
+ attention_mask: torch.Tensor,
826
+ input_tensor: torch.Tensor,
827
+ cache_position: torch.Tensor,
828
+ past_key_values: Cache,
829
+ output_attentions: bool,
830
  ):
831
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
832
  using_static_cache = isinstance(past_key_values, StaticCache)
 
913
  return causal_mask
914
 
915
 
916
+ class KwargsForCausalLM(LossKwargs): ...
917
+
918
+
919
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
920
  _tied_weights_keys = ["lm_head.weight"]
921
+ _tp_plan = {"lm_head": "colwise_rep"}
922
 
923
  def __init__(self, config: DogeConfig):
924
  super().__init__(config)
 
941
 
942
  def set_output_embeddings(self, new_embeddings):
943
  self.lm_head = new_embeddings
944
+
945
+ def get_decoder(self):
946
+ return self.model
947
 
948
  def set_decoder(self, decoder):
949
  self.model = decoder
950
 
 
 
 
951
  @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
952
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
953
  def forward(
 
955
  input_ids: torch.LongTensor = None,
956
  attention_mask: Optional[torch.Tensor] = None,
957
  position_ids: Optional[torch.LongTensor] = None,
958
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
959
  inputs_embeds: Optional[torch.FloatTensor] = None,
960
  labels: Optional[torch.LongTensor] = None,
961
  use_cache: Optional[bool] = None,
 
964
  return_dict: Optional[bool] = None,
965
  cache_position: Optional[torch.LongTensor] = None,
966
  num_logits_to_keep: int = 0,
967
+ **kwargs: Unpack[KwargsForCausalLM],
968
  ) -> Union[Tuple, CausalLMOutputWithPast]:
969
  r"""
970
  Args:
 
979
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
980
 
981
  Returns:
982
+
983
+ Example:
984
+
985
+ ```python
986
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
987
+
988
+ >>> model = AutoModelForCausalLM.from_pretrained("JingzeShi/Doge-20M-Instruct")
989
+ >>> tokenizer = AutoTokenizer.from_pretrained("JingzeShi/Doge-20M-Instruct")
990
+
991
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
992
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
993
+
994
+ >>> # Generate
995
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
996
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
997
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
998
+ ```"""
999
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1000
  output_hidden_states = (
1001
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1014
  output_hidden_states=output_hidden_states,
1015
  return_dict=return_dict,
1016
  cache_position=cache_position,
1017
+ **kwargs,
1018
  )
1019
 
1020
  hidden_states = outputs[0]
 
1024
 
1025
  loss = None
1026
  if labels is not None:
1027
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
1028
 
1029
  if not return_dict:
1030
  output = (logits,) + outputs[1:]
 
1039
  )
1040
 
1041
 
1042
+ class DogePatchEmbedding(nn.Module):
1043
+ """
1044
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` of shape `(batch_size, seq_len, hidden_size)` to be consumed by a Transformer.
1045
+ """
1046
+
1047
+ def __init__(self, config: DogeConfig):
1048
+ super().__init__()
1049
+
1050
+ self.num_channels = config.num_channels
1051
+ self.patch_size = config.patch_size
1052
+ self.hidden_dim = config.hidden_size
1053
+
1054
+ self.sequence_proj = nn.Conv2d(self.num_channels, self.hidden_dim, kernel_size=self.patch_size, stride=self.patch_size)
1055
+ self.state_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)
1056
+
1057
+ def forward(
1058
+ self,
1059
+ pixel_values: torch.Tensor,
1060
+ ) -> torch.Tensor:
1061
+ image_embedding = self.sequence_proj(pixel_values).flatten(2).transpose(1, 2)
1062
+ image_embedding = self.state_proj(image_embedding)
1063
+ return image_embedding
1064
+
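A quick shape walk-through of the patchification above (example values only: 3 channels, 16x16 patches, hidden size 512, a 224x224 input): the convolution produces one hidden vector per patch, which is then flattened into a token sequence.

```python
import torch
import torch.nn as nn

num_channels, patch_size, hidden_dim = 3, 16, 512        # example values
sequence_proj = nn.Conv2d(num_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
state_proj = nn.Linear(hidden_dim, hidden_dim)

pixel_values = torch.randn(1, num_channels, 224, 224)     # (batch, channels, height, width)
patches = sequence_proj(pixel_values)                      # (1, 512, 14, 14): one vector per 16x16 patch
image_embedding = state_proj(patches.flatten(2).transpose(1, 2))
print(image_embedding.shape)                               # torch.Size([1, 196, 512])
```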
1065
+
1066
+ class DogeForCausalVLM(DogeForCausalLM):
1067
+ _tied_weights_keys = ["lm_head.weight"]
1068
+
1069
+ def __init__(self, config: DogeConfig):
1070
+ super().__init__(config)
1071
+ self.config = config
1072
+ self.pixel_embed = DogePatchEmbedding(config)
1073
+
1074
+ # Initialize weights and apply final processing
1075
+ self.post_init()
1076
+
1077
+ def forward(
1078
+ self,
1079
+ input_ids: torch.LongTensor = None,
1080
+ pixel_values: torch.FloatTensor = None,
1081
+ attention_mask: Optional[torch.Tensor] = None,
1082
+ position_ids: Optional[torch.LongTensor] = None,
1083
+ past_key_values: Optional[torch.Tensor] = None,
1084
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1085
+ labels: Optional[torch.LongTensor] = None,
1086
+ use_cache: Optional[bool] = None,
1087
+ output_attentions: Optional[bool] = None,
1088
+ output_hidden_states: Optional[bool] = None,
1089
+ return_dict: Optional[bool] = None,
1090
+ cache_position: Optional[torch.LongTensor] = None,
1091
+ num_logits_to_keep: int = 0,
1092
+ **loss_kwargs,
1093
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1094
+ # TODO: @wubingheng111: refer to Llava for implementing the forward method
1095
+ ...
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self,
1099
+ input_ids=None,
1100
+ pixel_values=None,
1101
+ past_key_values=None,
1102
+ input_embeds=None,
1103
+ attention_mask=None,
1104
+ cache_position=None,
1105
+ num_logits_to_keep=None,
1106
+ **kwargs,
1107
+ ):
1108
+ model_inputs = self.model.prepare_inputs_for_generation(
1109
+ input_ids,
1110
+ past_key_values=past_key_values,
1111
+ inputs_embeds=input_embeds,
1112
+ attention_mask=attention_mask,
1113
+ cache_position=cache_position,
1114
+ num_logits_to_keep=num_logits_to_keep,
1115
+ **kwargs,
1116
+ )
1117
+
1118
+ if cache_position[0] == 0:
1119
+ model_inputs["pixel_values"] = pixel_values
1120
+
1121
+ return model_inputs
1122
+
1123
+
1124
  @add_start_docstrings(
1125
  """
1126
  The Doge Model transformer with a sequence classification head on top (linear layer).
1127
 
1128
+ [`DogeForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do.
 
1129
 
1130
+ Since it does classification on the last token, it needs to know the position of the last token.
1131
+ If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
1132
+ If no `pad_token_id` is defined, it simply takes the last value in each row of the batch.
1133
+ Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in each row of the batch).
 
1134
  """
1135
  )
1136
  class DogeForSequenceClassification(DogePreTrainedModel):
 
1167
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1168
  r"""
1169
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1170
+ Labels for computing the sequence classification/regression loss.
1171
+ Indices should be in `[0, ..., config.num_labels - 1]`.
1172
+ If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1173
  """
1174
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1175