quintic committed
Commit 072784c · 1 Parent(s): 55258da

initial commit

Files changed (7)
  1. .gitignore +0 -0
  2. README.md +0 -0
  3. __init__.py +0 -0
  4. config.json +22 -0
  5. configuration_aria.py +50 -0
  6. model.safetensors +3 -0
  7. modeling_aria.py +565 -0
.gitignore ADDED
File without changes
README.md ADDED
File without changes
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "AriaForCausalLM"
+   ],
+   "bos_token_id": 0,
+   "eos_token_id": 1,
+   "hidden_size": 1536,
+   "intermediate_size": 6144,
+   "max_position_embeddings": 8192,
+   "model_type": "aria",
+   "num_attention_heads": 64,
+   "num_hidden_layers": 16,
+   "torch_dtype": "float16",
+   "transformers_version": "4.45.0",
+   "use_cache": true,
+   "vocab_size": 17731,
+   "auto_map": {
+     "AutoConfig": "configuration_aria.AriaConfig",
+     "AutoModel": "modeling_aria.AriaModel",
+     "AutoModelForCausalLM": "modeling_aria.AriaForCausalLM"
+   }
+ }
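
Because config.json declares an `auto_map`, the checkpoint is meant to be loaded through the Auto classes with `trust_remote_code=True`, which pulls `configuration_aria.py` and `modeling_aria.py` from the repo. A minimal loading sketch; the repository id below is a placeholder, not something this commit states:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "user/aria-checkpoint"  # hypothetical repo id; replace with the actual one

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)   # -> AriaConfig
model = AutoModelForCausalLM.from_pretrained(                          # -> AriaForCausalLM
    repo_id,
    trust_remote_code=True,
    torch_dtype="auto",  # picks up "torch_dtype": "float16" from config.json
)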
configuration_aria.py ADDED
@@ -0,0 +1,50 @@
+ from transformers import PretrainedConfig
+
+
+ class AriaConfig(PretrainedConfig):
+     model_type = "aria"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size: int = 17731,
+         hidden_size: int = 1536,
+         num_hidden_layers: int = 16,
+         num_attention_heads: int = 64,
+         intermediate_size: int = 6144,
+         max_position_embeddings: int = 8192,
+         use_cache: bool = True,
+         bos_token_id: int = 0,
+         eos_token_id: int = 1,
+         tie_word_embeddings: bool = False,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = False,
+         initializer_range: float = 0.02,
+         **kwargs,
+     ):
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.max_position_embeddings = max_position_embeddings
+         self.use_cache = use_cache
+         self.tie_word_embeddings = tie_word_embeddings
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+         self.return_dict = return_dict
+         # Used by AriaPreTrainedModel._init_weights in modeling_aria.py.
+         self.initializer_range = initializer_range
+
+         if self.intermediate_size % self.hidden_size != 0:
+             raise ValueError("The intermediate size needs to be divisible by the hidden size.")
+
+         if self.hidden_size % self.num_attention_heads != 0:
+             raise ValueError("The hidden size needs to be divisible by the number of attention heads.")
+
+     @property
+     def ff_mult(self):
+         return self.intermediate_size // self.hidden_size
+
+
+ __all__ = ["AriaConfig"]
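
A quick sanity-check sketch for the config class (run from a checkout of this repo so the import resolves); the numbers follow directly from the defaults above:

from configuration_aria import AriaConfig

cfg = AriaConfig()  # defaults mirror config.json
print(cfg.ff_mult)                                  # 6144 // 1536 == 4
print(cfg.hidden_size // cfg.num_attention_heads)   # per-head dim: 1536 // 64 == 24

# The constructor enforces the divisibility constraints:
try:
    AriaConfig(intermediate_size=6000)              # 6000 is not a multiple of 1536
except ValueError as err:
    print(err)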
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e592f31b380742f5426c0c80c8cac65efc97c6981f3b7b6b3eee193793d0116d
+ size 2634219792
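
The committed model.safetensors is only a Git LFS pointer; the actual weight file (about 2.6 GB, which for this architecture is consistent with roughly 0.66 B parameters stored in float32) lives in LFS. Once the real file has been fetched, for example with `git lfs pull`, a small sketch for peeking at it with the `safetensors` library:

from safetensors import safe_open

# Assumes the real weights have replaced the LFS pointer in the working tree.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in list(f.keys())[:5]:
        print(name, f.get_tensor(name).shape)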
modeling_aria.py ADDED
@@ -0,0 +1,565 @@
+ # This is lightly adapted from https://github.com/EleutherAI/aria/blob/main/aria/model.py
+
+ from typing import Optional, Union, Tuple
+
+ import torch
+ import torch.utils.checkpoint
+
+ from torch import nn
+ from torch.nn import functional as F, CrossEntropyLoss
+
+ from transformers import Cache, DynamicCache, StaticCache
+ from transformers.utils import logging
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ from .configuration_aria import AriaConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class AriaPreTrainedModel(PreTrainedModel):
+     config_class = AriaConfig
+     base_model_prefix = "aria"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["AriaBlock"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = False
+     _supports_cache_class = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+     _supports_sdpa = True
+     _supports_flex_attn = False
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ class AriaBlock(nn.Module):
+     def __init__(self, model_config: AriaConfig, layer_idx: int):
+         super().__init__()
+
+         self.drop_p = 0.0
+         self.n_heads = model_config.num_attention_heads
+         self.d_model = model_config.hidden_size
+         self.d_head = model_config.hidden_size // model_config.num_attention_heads
+         self.max_seq_len = model_config.max_position_embeddings
+         self.layer_idx = layer_idx
+
+         # Attention
+         self.mixed_qkv = nn.Linear(
+             in_features=self.d_model,
+             out_features=3 * self.d_model,
+             bias=False,
+         )
+         self.att_proj_linear = nn.Linear(
+             in_features=self.d_model,
+             out_features=self.d_model,
+             bias=False,
+         )
+
+         # FF layer (SwiGLU)
+         self.ff_gate_proj = nn.Linear(
+             in_features=self.d_model,
+             out_features=self.d_model * model_config.ff_mult,
+             bias=False,
+         )
+         self.ff_up_proj = nn.Linear(
+             in_features=self.d_model,
+             out_features=self.d_model * model_config.ff_mult,
+             bias=False,
+         )
+         self.ff_down_proj = nn.Linear(
+             in_features=self.d_model * model_config.ff_mult,
+             out_features=self.d_model,
+             bias=False,
+         )
+
+         # Pre-layer norms
+         self.norm1 = nn.LayerNorm(self.d_model)
+         self.norm2 = nn.LayerNorm(self.d_model)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         position_ids: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         attn_output, attn_weights, present = self._att_block(
+             self.norm1(x),
+             attention_mask,
+             freqs_cis,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             cache_position=cache_position,
+         )
+
+         x = x + attn_output
+         x = x + self._ff_block(self.norm2(x))
+
+         if use_cache:
+             outputs = (x, present, attn_weights)
+         else:
+             outputs = (x, attn_weights)
+
+         return outputs
+
+     def _att_block(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         batch_size, seq_len, _ = x.shape
+         mixed_qkv = self.mixed_qkv(x)
+         xq, xk, xv = mixed_qkv.chunk(3, -1)
+
+         # Reshape for rotary embeddings.
+         # Need contiguous for q, k since in-place RoPE cannot be applied on a view.
+         xq = xq.reshape(batch_size, seq_len, self.n_heads, self.d_head).contiguous()
+         xk = xk.reshape(batch_size, seq_len, self.n_heads, self.d_head).contiguous()
+         xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)
+
+         # apply_rotary_emb expects: (b_sz, s_len, n_head, d_head)
+         xq = apply_rotary_emb(xq, freqs_cis)
+         xk = apply_rotary_emb(xk, freqs_cis)
+         xq, xk, xv = map(lambda t: t.transpose(1, 2), (xq, xk, xv))
+
+         if past_key_values is not None:
+             cache_kwargs = {"cache_position": cache_position}
+             xk, xv = past_key_values.update(xk, xv, self.layer_idx, cache_kwargs)
+
+         # scaled_dot_product_attention expects: (b_sz, n_head, s_len, d_head).
+         # `attn_mask` and `is_causal=True` are mutually exclusive in SDPA, so causality is
+         # only requested when no explicit mask was built (and not during single-token
+         # decoding, where the cached keys already cover all earlier positions).
+         att = F.scaled_dot_product_attention(
+             query=xq,
+             key=xk,
+             value=xv,
+             attn_mask=attention_mask,
+             is_causal=attention_mask is None and seq_len > 1,
+         )
+
+         # Reshape for out: (b_sz, s_len, n_head, d_head)
+         out = att.transpose(1, 2).contiguous()
+         out = out.view(batch_size, seq_len, self.n_heads * self.d_head)
+
+         # Note: SDPA does not expose attention probabilities, so `att` here is the
+         # attention output rather than the attention weights.
+         if not output_attentions:
+             att = None
+
+         return self.att_proj_linear(out), att, past_key_values
+
+     def _ff_block(self, x: torch.Tensor):
+         return self.ff_down_proj(F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x))
+
+
+ class AriaModel(AriaPreTrainedModel):
+     """Transformer decoder with no language model head.
+
+     Args:
+         model_config (AriaConfig): Model config settings.
+     """
+
+     def __init__(self, model_config: AriaConfig):
+         super().__init__(model_config)
+         self.model_config = model_config
+         self.freqs_cis = None
+
+         self.tok_embeddings = nn.Embedding(
+             num_embeddings=model_config.vocab_size,
+             embedding_dim=model_config.hidden_size,
+         )
+
+         self.out_layer_norm = nn.LayerNorm(model_config.hidden_size)
+         self.encode_layers = nn.ModuleList()
+         for i in range(model_config.num_hidden_layers):
+             self.encode_layers.append(AriaBlock(model_config, i))
+
+         self.gradient_checkpointing = False
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         """Forward pass of the Transformer decoder.
+
+         Args:
+             input_ids (Optional[torch.Tensor]): Token ids of shape (batch_size, seq_len).
+             attention_mask (Optional[torch.Tensor]): Attention mask of shape
+                 (batch_size, seq_len). Defaults to None.
+             past_key_values (Optional[Cache]): Key/value cache; the legacy tuple-of-tuples
+                 format is also accepted and converted.
+
+         Returns:
+             BaseModelOutputWithPast or tuple: Hidden states of shape (batch_size, seq_len,
+                 d_model), plus the cache / hidden states / attentions when requested.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.model_config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.model_config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.model_config.use_return_dict
+         use_cache = use_cache if use_cache is not None else self.model_config.use_cache
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.tok_embeddings(input_ids)
+
+         return_legacy_cache = False
+         if use_cache and not isinstance(past_key_values, Cache):
+             return_legacy_cache = True
+             if past_key_values is None:
+                 past_key_values = DynamicCache()
+             else:
+                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                 logger.warning_once(
+                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                 )
+
+         seq_length = inputs_embeds.shape[1]
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+         hidden_states = inputs_embeds
+
+         causal_mask = self._update_causal_mask(
+             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         )
+
+         if self.freqs_cis is None:
+             self.freqs_cis = precompute_freqs_cis(
+                 seq_len=self.model_config.max_position_embeddings,
+                 n_elem=self.model_config.hidden_size // self.model_config.num_attention_heads,
+                 base=500000,
+                 dtype=hidden_states.dtype,
+             ).to(inputs_embeds.device)
+         # Index by absolute position so rotary embeddings stay correct during cached decoding.
+         freqs_cis = self.freqs_cis[cache_position]
+
+         kwargs = {
+             "position_ids": position_ids,
+             "past_key_values": past_key_values,
+             "use_cache": use_cache,
+             "output_attentions": output_attentions,
+             "output_hidden_states": output_hidden_states,
+             "return_dict": return_dict,
+             "cache_position": cache_position,
+         }
+         next_decoder_cache = None
+         all_attentions = () if output_attentions else None
+         all_hidden_states = () if output_hidden_states else None
+         for layer in self.encode_layers:
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 # Reentrant checkpointing only supports positional tensor arguments, so the
+                 # layer is wrapped to take (hidden_states, causal_mask, freqs_cis) and return
+                 # only the hidden states; use_cache is already forced off above.
+                 def create_custom_forward(module):
+                     def custom_forward(*args):
+                         return module(*args)[0]
+
+                     return custom_forward
+
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(layer),
+                     hidden_states,
+                     causal_mask,
+                     freqs_cis,
+                     preserve_rng_state=True,
+                     use_reentrant=True,
+                 )
+             else:
+                 outputs = layer(hidden_states, causal_mask, freqs_cis=freqs_cis, **kwargs)
+                 hidden_states = outputs[0]
+                 if use_cache:
+                     next_decoder_cache = outputs[1]
+                 if output_attentions:
+                     all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         hidden_states = self.out_layer_norm(hidden_states)
+         next_cache = next_decoder_cache if use_cache else None
+
+         if return_legacy_cache:
+             next_cache = next_cache.to_legacy_cache()
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_attentions,
+         )
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool,
+     ):
+         if self.model_config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and (attention_mask == 0.0).any():
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+         using_static_cache = isinstance(past_key_values, StaticCache)
+
+         # When output_attentions is True, the sdpa implementation's forward method calls the eager implementation's forward
+         if self.model_config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         if using_static_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.model_config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+         return causal_mask
+
+     @staticmethod
+     # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                     padding_mask, min_dtype
+                 )
+
+         return causal_mask
+
+
+ class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
+     """Transformer decoder with a head for language modelling.
+
+     Args:
+         model_config (AriaConfig): Model config settings.
+     """
+
+     def __init__(self, model_config: AriaConfig):
+         super().__init__(model_config)
+         self.model_config = model_config
+         self.max_seq_len = model_config.max_position_embeddings
+         self.model = AriaModel(model_config)
+         self.lm_head = nn.Linear(
+             model_config.hidden_size, model_config.vocab_size, bias=False
+         )
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
+         labels: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.Tensor] = None,
+     ):
+         """Forward pass of the Transformer decoder with LM head."""
+         return_dict = return_dict if return_dict is not None else self.model_config.use_return_dict
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+         hidden = outputs[0]
+         lm_logits = self.lm_head(hidden)
+
+         lm_loss = None
+         if labels is not None:
+             # Move labels to the correct device to enable model parallelism
+             labels = labels.to(lm_logits.device)
+             # Next-token prediction: shift prediction scores and input ids by one
+             shift_logits = lm_logits[:, :-1, :].contiguous()
+             labels = labels[:, 1:].contiguous()
+             loss_fct = CrossEntropyLoss()
+             lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+
+         if not return_dict:
+             output = (lm_logits,) + outputs[1:]
+             return ((lm_loss,) + output) if lm_loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=lm_loss,
+             logits=lm_logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ def precompute_freqs_cis(
+     seq_len: int,
+     n_elem: int,
+     base: int = 500000,
+     dtype: torch.dtype = torch.bfloat16,
+ ):
+     freqs = 1.0 / (
+         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+     )
+     t = torch.arange(seq_len, device=freqs.device)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+
+     return cache.to(dtype=dtype)
+
+
+ @torch.jit.script
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+     """
+     In-place RoPE. Credits to Katherine Crowson.
+     x shape (b_sz, s_len, n_head, d_head).
+     freqs_cis shape (s_len, d_head // 2, 2), i.e. stacked (cos, sin).
+     """
+     d = x.shape[-1] // 2
+     cos = freqs_cis[..., 0][None, :, None]
+     sin = freqs_cis[..., 1][None, :, None]
+     x1, x2 = x[..., :d], x[..., d : d * 2]
+     tmp = x1.clone()
+     x1.mul_(cos).addcmul_(x2, sin, value=-1)
+     x2.mul_(cos).addcmul_(tmp, sin, value=1)
+     return x
+
+
+ __all__ = [
+     "AriaForCausalLM",
+     "AriaBlock",
+     "AriaModel",
+     "AriaPreTrainedModel",
+ ]
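
To see the pieces above working together, here is a small smoke-test sketch using a tiny, made-up configuration so it runs on CPU in seconds. The `aria` package path is an assumption about how the repo is checked out (it does ship an `__init__.py`); the released checkpoint itself would normally be loaded through `AutoModelForCausalLM` as shown under config.json. The sketch exercises the forward pass and incremental decoding with a `DynamicCache`:

import torch
from transformers import DynamicCache

# Hypothetical import path: assumes the repo was cloned into a directory named `aria`.
from aria.configuration_aria import AriaConfig
from aria.modeling_aria import AriaForCausalLM

# Tiny config for a quick test; the real checkpoint uses the values in config.json.
cfg = AriaConfig(
    vocab_size=128,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    max_position_embeddings=64,
    return_dict=True,
)
model = AriaForCausalLM(cfg).eval()

tokens = torch.randint(0, cfg.vocab_size, (1, 8))
past = DynamicCache()
with torch.no_grad():
    out = model(tokens, past_key_values=past, use_cache=True)
print(out.logits.shape)  # torch.Size([1, 8, 128])

# Greedy incremental decoding: feed only the newest token and reuse the KV cache.
for _ in range(4):
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_id], dim=-1)
    with torch.no_grad():
        out = model(next_id, past_key_values=out.past_key_values, use_cache=True)
print(tokens.shape)  # torch.Size([1, 12])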