quintic committed
Commit fd1489d · 1 Parent(s): aafacc4

add tokenizer; reformat

configuration_aria.py CHANGED
@@ -36,15 +36,18 @@ class AriaConfig(PretrainedConfig):
         self.return_dict = return_dict
 
         if self.intermediate_size % self.hidden_size != 0:
-            raise ValueError("The intermediate size needs to be divisible by hidden size.")
+            raise ValueError(
+                "The intermediate size needs to be divisible by hidden size."
+            )
 
         if self.hidden_size % self.num_attention_heads != 0:
-            raise ValueError("The hidden size needs to be divisible by the number of attention heads.")
+            raise ValueError(
+                "The hidden size needs to be divisible by the number of attention heads."
+            )
 
     @property
     def ff_mult(self):
         return self.intermediate_size // self.hidden_size
 
 
-
 __all__ = ["AriaConfig"]
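
The two checks reformatted above enforce that `intermediate_size` is a multiple of `hidden_size` (so `ff_mult` is an integer expansion factor) and that `hidden_size` is a multiple of `num_attention_heads` (so each head gets a whole head dimension). A minimal standalone sketch of the same constraints; the helper `check_dims` is illustrative and not part of `AriaConfig`:

# Standalone sketch of the divisibility checks performed in AriaConfig.__init__.
def check_dims(hidden_size: int, intermediate_size: int, num_attention_heads: int) -> int:
    if intermediate_size % hidden_size != 0:
        raise ValueError("The intermediate size needs to be divisible by hidden size.")
    if hidden_size % num_attention_heads != 0:
        raise ValueError("The hidden size needs to be divisible by the number of attention heads.")
    # ff_mult: integer expansion factor of the feed-forward block.
    return intermediate_size // hidden_size

# Example values (hypothetical): hidden_size=1024, intermediate_size=4096, 16 heads -> ff_mult == 4.
assert check_dims(1024, 4096, 16) == 4
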
modeling_aria.py CHANGED
@@ -1,7 +1,6 @@
 # This is lightly adapted from https://github.com/EleutherAI/aria/blob/main/aria/model.py
 
-from dataclasses import dataclass
-from typing import Optional, Union, Tuple, List
+from typing import Optional, Union, Tuple
 
 import torch
 import torch.utils.checkpoint
@@ -13,7 +12,10 @@ from transformers import Cache, DynamicCache, StaticCache
 from transformers.utils import logging
 from transformers.generation import GenerationMixin
 from transformers.modeling_utils import PreTrainedModel
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
 from .configuration_aria import AriaConfig
@@ -94,7 +96,7 @@ class AriaBlock(nn.Module):
         self.norm2 = nn.LayerNorm(self.d_model)
 
     def forward(
-        self,
+        self,
         x: torch.Tensor,
         attention_mask: torch.Tensor,
         freqs_cis: torch.Tensor,
@@ -104,13 +106,17 @@ class AriaBlock(nn.Module):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.Tensor] = None
+        cache_position: Optional[torch.Tensor] = None,
     ):
-        attn_output, attn_weights, present = self._att_block(self.norm1(x), attention_mask, freqs_cis,
-                                                             past_key_values=past_key_values,
-                                                             use_cache=use_cache,
-                                                             output_attentions=output_attentions,
-                                                             cache_position=cache_position)
+        attn_output, attn_weights, present = self._att_block(
+            self.norm1(x),
+            attention_mask,
+            freqs_cis,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
 
         x = x + attn_output
         x = x + self._ff_block(self.norm2(x))
@@ -131,7 +137,7 @@ class AriaBlock(nn.Module):
         past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
-        cache_position: Optional[torch.Tensor] = None
+        cache_position: Optional[torch.Tensor] = None,
     ):
         batch_size, seq_len, _ = x.shape
         mixed_qkv = self.mixed_qkv(x)
@@ -139,12 +145,8 @@ class AriaBlock(nn.Module):
 
         # Reshape for rotary embeddings
         # Need contiguous for q, k since in-place RoPE cannot be applied on a view
-        xq = xq.reshape(
-            batch_size, seq_len, self.n_heads, self.d_head
-        ).contiguous()
-        xk = xk.reshape(
-            batch_size, seq_len, self.n_heads, self.d_head
-        ).contiguous()
+        xq = xq.reshape(batch_size, seq_len, self.n_heads, self.d_head).contiguous()
+        xk = xk.reshape(batch_size, seq_len, self.n_heads, self.d_head).contiguous()
         xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)
 
         # apply_rotary_post_emb expects: (b_sz, s_len, n_head, d_head)
@@ -154,9 +156,9 @@ class AriaBlock(nn.Module):
 
         if past_key_values is not None:
             cache_kwargs = {
-                #"sin": sin,
-                #"cos": cos,
-                #"partial_rotation_size": self.rotary_ndims,
+                # "sin": sin,
+                # "cos": cos,
+                # "partial_rotation_size": self.rotary_ndims,
                 "cache_position": cache_position,
             }
             xk, xv = past_key_values.update(xk, xv, self.layer_idx, cache_kwargs)
@@ -179,10 +181,7 @@ class AriaBlock(nn.Module):
         return self.att_proj_linear(out), att, past_key_values
 
     def _ff_block(self, x: torch.Tensor):
-
-        return self.ff_down_proj(
-            F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x)
-        )
+        return self.ff_down_proj(F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x))
 
 
 class AriaModel(AriaPreTrainedModel):
@@ -237,15 +236,27 @@ class AriaModel(AriaPreTrainedModel):
             torch.tensor: Model outputs with shape (batch_size, seq_len,
                 d_model).
         """
-        output_attentions = output_attentions if output_attentions is not None else self.model_config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.model_config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.model_config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.model_config.output_hidden_states
+        )
+        return_dict = (
+            return_dict
+            if return_dict is not None
+            else self.model_config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.model_config.use_return_dict
         use_cache = use_cache if use_cache is not None else self.model_config.use_cache
 
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -272,21 +283,32 @@ class AriaModel(AriaPreTrainedModel):
 
         seq_length = inputs_embeds.shape[1]
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + seq_length,
+                device=inputs_embeds.device,
+            )
 
         if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
 
        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            past_key_values,
+            output_attentions,
        )
 
        if self.freqs_cis is None:
            self.freqs_cis = precompute_freqs_cis(
                seq_len=self.model_config.max_position_embeddings,
-                n_elem=self.model_config.hidden_size // self.model_config.num_attention_heads,
+                n_elem=self.model_config.hidden_size
+                // self.model_config.num_attention_heads,
                base=500000,
                dtype=hidden_states.dtype,
            ).to(input_ids.device)
@@ -326,7 +348,9 @@ class AriaModel(AriaPreTrainedModel):
         for layer in self.encode_layers:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            outputs = layer(hidden_states, causal_mask, freqs_cis=freqs_cis, **kwargs)
+            outputs = layer(
+                hidden_states, causal_mask, freqs_cis=freqs_cis, **kwargs
+            )
             hidden_states = outputs[0]
             if use_cache is True:
                 next_decoder_cache = outputs[1]
@@ -342,7 +366,11 @@ class AriaModel(AriaPreTrainedModel):
             next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_attentions]
+                if v is not None
+            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
@@ -367,11 +395,17 @@ class AriaModel(AriaPreTrainedModel):
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
         # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        past_seen_tokens = (
+            past_key_values.get_seq_length() if past_key_values is not None else 0
+        )
         using_static_cache = isinstance(past_key_values, StaticCache)
 
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.model_config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if (
+            self.model_config._attn_implementation == "sdpa"
+            and not using_static_cache
+            and not output_attentions
+        ):
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
@@ -412,7 +446,9 @@ class AriaModel(AriaPreTrainedModel):
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
             min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
+            )
 
         return causal_mask
 
@@ -434,20 +470,30 @@ class AriaModel(AriaPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length),
+                fill_value=min_dtype,
+                dtype=dtype,
+                device=device,
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(
+                target_length, device=device
+            ) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                causal_mask = (
+                    causal_mask.clone()
+                )  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
+                padding_mask = (
+                    causal_mask[:, :, :, :mask_length]
+                    + attention_mask[:, None, None, :]
                 )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
 
 
         return causal_mask
@@ -483,7 +529,11 @@ class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
         cache_position: Optional[torch.Tensor] = None,
     ):
         """Forward pass of Transformer decoder with LM head."""
-        return_dict = return_dict if return_dict is not None else self.model_config.use_return_dict
+        return_dict = (
+            return_dict
+            if return_dict is not None
+            else self.model_config.use_return_dict
+        )
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -507,7 +557,9 @@ class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
             shift_logits = lm_logits[:, :-1, :].contiguous()
             labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+            lm_loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+            )
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
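
For reference, the `_ff_block` collapsed onto one line in the diff above is a gated, SwiGLU-style feed-forward: a SiLU-activated gate projection multiplied elementwise with an up projection, then projected back down to `d_model`. A self-contained sketch with made-up dimensions; the projection names mirror the diff, but the module below is illustrative and not the actual `AriaBlock`:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFFSketch(nn.Module):
    """Illustrative stand-in for AriaBlock._ff_block: down(silu(gate(x)) * up(x))."""

    def __init__(self, d_model: int, ff_mult: int):
        super().__init__()
        self.ff_gate_proj = nn.Linear(d_model, d_model * ff_mult, bias=False)
        self.ff_up_proj = nn.Linear(d_model, d_model * ff_mult, bias=False)
        self.ff_down_proj = nn.Linear(d_model * ff_mult, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate and up projections expand to d_model * ff_mult; down projection restores d_model.
        return self.ff_down_proj(F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x))

x = torch.randn(2, 8, 64)              # (batch, seq_len, d_model) with toy sizes
print(GatedFFSketch(64, 4)(x).shape)   # torch.Size([2, 8, 64])
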
tokenization_aria.py ADDED
@@ -0,0 +1,163 @@
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
+from transformers.utils import logging, TensorType, to_py_obj
+
+try:
+    from ariautils.midi import MidiDict
+    from ariautils.tokenizer import AbsTokenizer
+    from ariautils.tokenizer._base import Token
+except ImportError:
+    raise ImportError(
+        "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
+    )
+
+if TYPE_CHECKING:
+    pass
+
+logger = logging.get_logger(__name__)
+
+
+class AriaTokenizer(PreTrainedTokenizer):
+    """
+    Aria Tokenizer is NOT a BPE tokenizer. A midi file will be converted to a MidiDict (note: in fact, a MidiDict is not a single dict. It is more about a list of "notes") which represents a sequence of notes, stops, etc. And then, aria tokenizer is simply a dictionary that maps MidiDict to discrete indices according to a hard-coded rule.
+
+    For a FIM finetuned model, we also follow a simple FIM format to guide a piece of music to a (possibly very different) suffix according to the prompts:
+    <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>
+    This way, we expect a continuation that connects PROMPT and GUIDANCE.
+    """
+
+    vocab_files_names = {}
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        use_default_system_prompt=False,
+        **kwargs,
+    ):
+        self._tokenizer = AbsTokenizer()
+
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.use_default_system_prompt = use_default_system_prompt
+
+        bos_token = self._tokenizer.bos_tok
+        eos_token = self._tokenizer.eos_tok
+        pad_token = self._tokenizer.pad_tok
+        unk_token = self._tokenizer.unk_tok
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            use_default_system_prompt=use_default_system_prompt,
+            **kwargs,
+        )
+
+    def __getstate__(self):
+        return {}
+
+    def __setstate__(self, d):
+        raise NotImplementedError()
+
+    @property
+    def vocab_size(self):
+        """Returns vocab size"""
+        return self._tokenizer.vocab_size
+
+    def get_vocab(self):
+        return self._tokenizer.tok_to_id
+
+    def tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
+        return self._tokenizer(midi_dict)
+
+    def _tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
+        return self._tokenizer(midi_dict)
+
+    def __call__(
+        self,
+        midi_dicts: MidiDict | list[MidiDict],
+        padding: bool = False,
+        max_length: int | None = None,
+        pad_to_multiple_of: int | None = None,
+        return_tensors: str | TensorType | None = None,
+        return_attention_mask: bool | None = None,
+        **kwargs,
+    ) -> BatchEncoding:
+        """It is impossible to rely on the parent method because the inputs are MidiDict(s) instead of strings. I do not like the idea of going hacky so that two entirely different types of inputs can marry. So here I reimplement __call__ with limited support of certain useful arguments. I do not expect any conflict with other "string-in-ids-out" tokenizers. If you have to mix up the API of string-based tokenizers and our midi-based tokenizer, there must be a problem with your design."""
+        if isinstance(midi_dicts, MidiDict):
+            midi_dicts = [midi_dicts]
+
+        all_tokens: list[list[int]] = []
+        all_attn_masks: list[list[int]] = []
+        max_len_encoded = 0
+        # TODO: if we decide to optimize batched tokenization on ariautils using some compiled backend, we can change this loop accordingly.
+        for md in midi_dicts:
+            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
+            if max_length is not None:
+                tokens = tokens[:max_length]
+            max_len_encoded = max(max_len_encoded, len(tokens))
+            all_tokens.append(tokens)
+            all_attn_masks.append([True] * len(tokens))
+
+        if pad_to_multiple_of is not None:
+            max_len_encoded = (
+                (max_len_encoded + pad_to_multiple_of) // pad_to_multiple_of
+            ) * pad_to_multiple_of
+        if padding:
+            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
+                tokens.extend([self.pad_token_id] * (max_len_encoded - len(tokens)))
+                attn_mask.extend([False] * (max_len_encoded - len(tokens)))
+
+        return BatchEncoding(
+            {
+                "input_ids": all_tokens,
+                "attention_masks": all_attn_masks,
+            },
+            tensor_type=return_tensors,
+        )
+
+    def decode(self, token_ids: List[Token], **kwargs) -> MidiDict:
+        token_ids = to_py_obj(token_ids)
+
+        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))
+
+    def batch_decode(
+        self, token_ids_list: List[List[Token]], **kwargs
+    ) -> List[MidiDict]:
+        results = []
+        for token_ids in token_ids_list:
+            # Can we simply yield (without breaking all HF wrappers)?
+            results.append(self.decode(token_ids))
+        return results
+
+    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
+        midi_dict = MidiDict.from_midi(filename)
+        return self(midi_dict, **kwargs)
+
+    def encode_from_files(self, filenames: list[str], **kwargs) -> BatchEncoding:
+        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
+        return self(midi_dicts, **kwargs)
+
+    def _convert_token_to_id(self, token: Token):
+        """Converts a token (tuple or str) into an id."""
+        return self._tokenizer.tok_to_id.get(
+            token, self._tokenizer.tok_to_id[self.unk_token]
+        )
+
+    def _convert_id_to_token(self, index: int):
+        """Converts an index (integer) in a token (tuple or str)."""
+        return self._tokenizer.id_to_tok.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
+        """Converts a sequence of tokens into a single MidiDict."""
+        return self._tokenizer.detokenize(tokens)
+
+    def save_vocabulary(
+        self, save_directory, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        raise NotImplementedError()
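
A rough usage sketch for the tokenizer added above, assuming `ariautils` is installed and the module is importable from the repo directory; "example.mid" is a placeholder path:

# Hypothetical usage of AriaTokenizer; "example.mid" stands in for any local MIDI file.
from tokenization_aria import AriaTokenizer

tokenizer = AriaTokenizer()

# MIDI file -> MidiDict -> token ids, optionally truncated/padded and returned as tensors.
enc = tokenizer.encode_from_file("example.mid", padding=True, return_tensors="pt")
input_ids = enc["input_ids"]          # shape: (1, seq_len)

# Token ids -> MidiDict, round-tripping through the ariautils AbsTokenizer.
midi_dict = tokenizer.decode(input_ids[0])

Note that the batched output uses the key "attention_masks", while `model_input_names` declares "attention_mask"; callers expecting the standard key may need to account for that.
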
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
+{
+  "add_bos_token": false,
+  "add_eos_token": false,
+  "auto_map": {
+    "AutoTokenizer": [
+      "tokenization_aria.AriaTokenizer",
+      null
+    ]
+  },
+  "tokenizer_class": "AriaTokenizer"
+}
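
The `auto_map` entry routes `AutoTokenizer` to `tokenization_aria.AriaTokenizer` (the `null` slot is the absent fast tokenizer). A loading sketch, assuming the files sit in a Hub repo or local directory; the repo id below is a placeholder:

from transformers import AutoTokenizer

# Placeholder repo id; trust_remote_code is needed because the tokenizer class
# lives in the repo's tokenization_aria.py rather than in transformers itself.
tokenizer = AutoTokenizer.from_pretrained(
    "your-namespace/aria-model",
    trust_remote_code=True,
)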