medium-e75-base-padded / tokenization_aria.py
quintic's picture
add tokenizer; reformat
fd1489d
raw
history blame
6.5 kB
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
from transformers.utils import logging, TensorType, to_py_obj
# ariautils provides the MIDI data model and the hard-coded Aria vocabulary;
# fail early with an actionable message when it is missing, keeping the
# original ImportError attached as the direct cause.
try:
    from ariautils.midi import MidiDict
    from ariautils.tokenizer import AbsTokenizer
    from ariautils.tokenizer._base import Token
except ImportError as err:
    raise ImportError(
        "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
    ) from err

logger = logging.get_logger(__name__)
class AriaTokenizer(PreTrainedTokenizer):
    """
    Aria tokenizer for MIDI data. It is NOT a BPE tokenizer.

    A MIDI file is first converted to a ``MidiDict``. Despite its name, a
    MidiDict is not a single dict; it is better thought of as a list of
    "notes", representing a sequence of notes, stops, etc. The Aria tokenizer
    then maps that MidiDict to discrete indices according to a hard-coded rule.

    For a FIM-finetuned model we also follow a simple FIM format that guides a
    piece of music towards a (possibly very different) suffix given by the
    prompts:

        <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>

    The model is then expected to produce a continuation that connects PROMPT
    and GUIDANCE.
    """

    # The vocabulary is built by ``AbsTokenizer()`` itself, so no vocab files
    # are read or written.
    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        """Wrap a fresh ``AbsTokenizer`` and register its special tokens.

        The bos/eos/pad/unk tokens are taken from the ariautils tokenizer and
        handed to the Hugging Face base class. Any extra ``kwargs`` are
        forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # Created before super().__init__() so that any vocabulary lookup the
        # base initializer performs (get_vocab / _convert_token_to_id) can
        # already resolve.
        self._tokenizer = AbsTokenizer()
        # NOTE(review): add_bos_token / add_eos_token are stored but never read
        # in this class; whether BOS/EOS end up in the encoded output is
        # presumably decided by AbsTokenizer.tokenize — confirm in ariautils.
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        super().__init__(
            bos_token=self._tokenizer.bos_tok,
            eos_token=self._tokenizer.eos_tok,
            unk_token=self._tokenizer.unk_tok,
            pad_token=self._tokenizer.pad_tok,
            # Previously accepted but silently dropped; forward it so it is
            # recorded like any other init argument.
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        """Return an empty state: pickled instances carry no data."""
        # NOTE(review): together with __setstate__ below this makes a pickle
        # round-trip fail; presumably deliberate — confirm before relying on
        # pickling (e.g. multiprocessing data loaders).
        return {}

    def __setstate__(self, d):
        """Refuse to restore pickled state."""
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        """Return a copy of the token -> id mapping.

        A copy is returned so callers cannot mutate the ariautils tokenizer's
        internal table by accident.
        """
        return dict(self._tokenizer.tok_to_id)

    def tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        """Convert a MidiDict into its list of Aria tokens (kwargs are ignored)."""
        # Same call path as __call__, so both entry points agree.
        return self._tokenizer.tokenize(midi_dict)

    def _tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        """Base-class hook; identical to :meth:`tokenize` (kwargs are ignored)."""
        return self._tokenizer.tokenize(midi_dict)

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool | str = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encode one or more MidiDicts into a ``BatchEncoding``.

        It is impossible to rely on the parent method because the inputs are
        MidiDict(s) instead of strings, and making two entirely different input
        types share one code path would be hacky. So ``__call__`` is
        reimplemented here with limited support for the most useful arguments.
        No conflict with "string-in-ids-out" tokenizers is expected; if you have
        to mix the API of string-based tokenizers with this MIDI-based one,
        there is probably a problem with your design.

        Args:
            midi_dicts: A single MidiDict or a list of them.
            padding: ``False`` / ``"do_not_pad"`` leaves sequences ragged;
                ``True`` / ``"longest"`` pads to the longest sequence in the
                batch; ``"max_length"`` pads to ``max_length`` (or to the
                longest sequence when ``max_length`` is None).
            max_length: If given, every sequence is truncated to at most this
                many ids, keeping the prefix.
            pad_to_multiple_of: When padding, round the padded length up to
                the nearest multiple of this value (exact multiples are kept).
            return_tensors: Forwarded to ``BatchEncoding`` as ``tensor_type``.
            return_attention_mask: Whether to include ``"attention_mask"``.
                ``None`` means "include it if it is in ``model_input_names``".
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            A ``BatchEncoding`` with ``"input_ids"`` and, unless disabled,
            ``"attention_mask"`` (1 for real tokens, 0 for padding).
        """
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        # TODO: if we decide to optimize batched tokenization on ariautils
        # using some compiled backend, we can change this loop accordingly.
        batch_ids: list[list[int]] = []
        for md in midi_dicts:
            ids = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                ids = ids[:max_length]  # truncation keeps the prefix
            batch_ids.append(list(ids))

        do_pad = bool(padding) and padding != "do_not_pad"
        target_len = 0
        if do_pad:
            if padding == "max_length" and max_length is not None:
                target_len = max_length
            else:
                target_len = max((len(ids) for ids in batch_ids), default=0)
            if pad_to_multiple_of is not None and pad_to_multiple_of > 0:
                # Ceiling to the next multiple; an exact multiple stays as is.
                target_len = (
                    (target_len + pad_to_multiple_of - 1) // pad_to_multiple_of
                ) * pad_to_multiple_of

        # With target_len == 0 nothing gets padded and every mask is all ones.
        input_ids, attention_mask = self._pad_batch(batch_ids, target_len)

        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        data = {"input_ids": input_ids}
        if return_attention_mask:
            # Key must match model_input_names so that model(**encoding) works.
            data["attention_mask"] = attention_mask
        return BatchEncoding(data, tensor_type=return_tensors)

    def _pad_batch(
        self, batch_ids: list[list[int]], target_len: int
    ) -> tuple[list[list[int]], list[list[int]]]:
        """Pad each id list to ``target_len``; return ``(ids, attention_masks)``.

        Padding is placed on ``self.padding_side`` ("right" when the attribute
        is absent). Masks hold 1 for real tokens and 0 for padding. Sequences
        already at least ``target_len`` long are returned unpadded.
        """
        pad_id = self.pad_token_id
        pad_left = getattr(self, "padding_side", "right") == "left"
        padded_ids: list[list[int]] = []
        masks: list[list[int]] = []
        for ids in batch_ids:
            # Length of padding is computed before building the outputs, so ids
            # and mask always receive the same amount.
            n_pad = max(target_len - len(ids), 0)
            mask = [1] * len(ids)
            if pad_left:
                padded_ids.append([pad_id] * n_pad + ids)
                masks.append([0] * n_pad + mask)
            else:
                padded_ids.append(ids + [pad_id] * n_pad)
                masks.append(mask + [0] * n_pad)
        return padded_ids, masks

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        """Decode a sequence of ids (list, array or tensor) into a MidiDict.

        Extra kwargs (e.g. ``skip_special_tokens``) are accepted for API
        compatibility and ignored.
        """
        token_ids = to_py_obj(token_ids)
        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[int]], **kwargs
    ) -> List[MidiDict]:
        """Decode each id sequence of a batch into its own MidiDict."""
        token_ids_list = to_py_obj(token_ids_list)
        return [self.decode(token_ids, **kwargs) for token_ids in token_ids_list]

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        """Load one MIDI file and encode it; kwargs are forwarded to ``__call__``."""
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(self, filenames: list[str], **kwargs) -> BatchEncoding:
        """Load several MIDI files and encode them as one batch via ``__call__``."""
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id, falling back to the unk id."""
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) into a token (tuple or str), or the unk token."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Return the (empty) tuple of vocabulary files written.

        The vocabulary is hard-coded in ariautils and rebuilt by
        ``AbsTokenizer()`` on construction (``vocab_files_names`` is empty), so
        there is nothing to write. Returning an empty tuple instead of raising
        lets ``save_pretrained``, which calls this hook, proceed.
        """
        return ()