|
from typing import TYPE_CHECKING, List, Optional, Tuple |
|
|
|
from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding |
|
from transformers.utils import logging, TensorType, to_py_obj |
|
|
|
# ``ariautils`` is a hard requirement: it provides ``MidiDict``, ``AbsTokenizer``
# and ``Token``, which ``AriaTokenizer`` below is built on.
try:
    from ariautils.midi import MidiDict
    from ariautils.tokenizer import AbsTokenizer
    from ariautils.tokenizer._base import Token
except ImportError as err:
    # Chain the original error so a partial/incompatible installation (e.g. a
    # missing submodule) can be told apart from the package being absent.
    raise ImportError(
        "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
    ) from err


logger = logging.get_logger(__name__)
|
|
|
|
|
class AriaTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer wrapper around the ``ariautils`` ``AbsTokenizer``.

    The Aria tokenizer is NOT a BPE tokenizer. A MIDI file is first converted to
    a ``MidiDict``, which represents a sequence of notes, stops, etc. (despite
    its name, a ``MidiDict`` is closer to a list of notes than to a single
    dict). The Aria tokenizer is then simply a fixed mapping from that
    ``MidiDict`` to discrete ids, following a hard-coded rule. There is no
    learned vocabulary file.

    For a FIM (fill-in-the-middle) finetuned model, a simple FIM format is used
    to steer a piece of music towards a (possibly very different) suffix given
    by the guidance:

        <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>

    The model is then expected to generate a continuation that connects the
    PROMPT to the GUIDANCE.
    """

    # The vocabulary is hard-coded inside ``AbsTokenizer``; no files are needed.
    vocab_files_names = {}

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        """Create the wrapper around a default-constructed ``AbsTokenizer``.

        Args:
            add_bos_token: stored on the instance. NOTE(review): not consulted
                by ``__call__``; whether a BOS token is emitted is decided by
                ``AbsTokenizer.tokenize``.
            add_eos_token: stored on the instance; same caveat as above.
            clean_up_tokenization_spaces: forwarded to the base class.
            use_default_system_prompt: stored and forwarded to the base class.
            **kwargs: forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # Created before ``super().__init__``: ``get_vocab``/``vocab_size``
        # delegate to it, and the base initializer may call them while it
        # registers the special tokens.
        self._tokenizer = AbsTokenizer()

        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        super().__init__(
            bos_token=self._tokenizer.bos_tok,
            eos_token=self._tokenizer.eos_tok,
            unk_token=self._tokenizer.unk_tok,
            pad_token=self._tokenizer.pad_tok,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        """Return picklable state; the ``AbsTokenizer`` is rebuilt on load.

        The state must stay non-empty: pickle (and ``copy``) skip
        ``__setstate__`` entirely when the state is falsy.
        """
        state = self.__dict__.copy()
        state.pop("_tokenizer", None)
        return state

    def __setstate__(self, d):
        """Restore state and rebuild the rule-based tokenizer as ``__init__`` does."""
        self.__dict__.update(d)
        self._tokenizer = AbsTokenizer()

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        """Return a copy of the token -> id mapping (callers may mutate it)."""
        return dict(self._tokenizer.tok_to_id)

    def tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        """Convert a ``MidiDict`` into a list of Aria tokens (not ids)."""
        return self._tokenize(midi_dict, **kwargs)

    def _tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        """Run ``AbsTokenizer`` on *midi_dict*; extra ``kwargs`` are ignored."""
        return self._tokenizer(midi_dict)

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encode one or more ``MidiDict`` objects into token ids.

        The parent ``__call__`` cannot be reused because its inputs are strings,
        not ``MidiDict`` objects. This is therefore a deliberately small
        reimplementation that supports only the arguments below. Mixing the
        string-based tokenizer API with this MIDI-based one is not intended.
        Any extra ``kwargs`` are accepted and ignored.

        Args:
            midi_dicts: a single ``MidiDict`` or a sequence of them.
            padding: if true, pad every sequence to the longest one in the
                batch (see ``pad_to_multiple_of``), on ``self.padding_side``.
            max_length: if given, truncate each sequence to at most this many
                ids.
            pad_to_multiple_of: if given together with ``padding``, round the
                padded length up to the next multiple of this value. Lengths
                that are already a multiple are unchanged.
            return_tensors: forwarded to ``BatchEncoding`` as ``tensor_type``.
            return_attention_mask: pass ``False`` to omit ``attention_mask``;
                it is included otherwise.

        Returns:
            A ``BatchEncoding`` with ``input_ids`` and, unless disabled,
            ``attention_mask`` (1 for real tokens, 0 for padding).
        """
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        for md in midi_dicts:
            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                tokens = tokens[:max_length]
            # Own a mutable list so padding below can extend it in place.
            all_tokens.append(list(tokens))

        all_attn_masks: list[list[int]] = [[1] * len(tokens) for tokens in all_tokens]

        if padding:
            target_len = max((len(tokens) for tokens in all_tokens), default=0)
            if pad_to_multiple_of:
                # Ceiling to the next multiple.
                target_len = (
                    (target_len + pad_to_multiple_of - 1) // pad_to_multiple_of
                ) * pad_to_multiple_of

            pad_left = self.padding_side == "left"
            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
                # Compute once, before ``tokens`` grows.
                n_pad = target_len - len(tokens)
                if n_pad <= 0:
                    continue
                pad_ids = [self.pad_token_id] * n_pad
                pad_mask = [0] * n_pad
                if pad_left:
                    tokens[:0] = pad_ids
                    attn_mask[:0] = pad_mask
                else:
                    tokens.extend(pad_ids)
                    attn_mask.extend(pad_mask)

        encoding = {"input_ids": all_tokens}
        if return_attention_mask is not False:
            encoding["attention_mask"] = all_attn_masks

        return BatchEncoding(encoding, tensor_type=return_tensors)

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        """Convert token ids (list, array or tensor) back into a ``MidiDict``.

        Extra ``kwargs`` are accepted for API compatibility and ignored.
        """
        token_ids = to_py_obj(token_ids)
        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[int]], **kwargs
    ) -> List[MidiDict]:
        """Decode each id sequence in *token_ids_list* into a ``MidiDict``."""
        return [self.decode(token_ids, **kwargs) for token_ids in token_ids_list]

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        """Load a MIDI file into a ``MidiDict`` and encode it via ``__call__``."""
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(self, filenames: list[str], **kwargs) -> BatchEncoding:
        """Load several MIDI files and encode them as one batch via ``__call__``."""
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id, falling back to the unk id."""
        tok_to_id = self._tokenizer.tok_to_id
        if token in tok_to_id:
            return tok_to_id[token]
        return tok_to_id[self.unk_token]

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) in a token (tuple or str), or the unk token."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Not supported: the vocabulary is hard-coded in ``AbsTokenizer``.

        NOTE(review): this presumably makes ``save_pretrained`` fail as well.
        """
        raise NotImplementedError()
|
|