import json
import os

from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers
from transformers import PreTrainedTokenizer


class OBITokenizer(PreTrainedTokenizer):
    """Byte-level BPE tokenizer built on a `tokenizers` backend."""

    def __init__(
        self,
        unk_token="[UNK]",
        bos_token="[CLS]",
        eos_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        auto_map={"AutoTokenizer": ["tokenizeConfig.OBITokenizer"]},
        tokenizer_class="OBITokenizer",
        **kwargs,
    ):
        # Build the byte-level BPE backend before calling the parent
        # constructor, which looks up the vocabulary while registering
        # special tokens.
        self.tokenizer = Tokenizer(models.BPE())
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
        self.tokenizer.decoder = decoders.ByteLevel()

        # Make sure every special token exists in the backend vocabulary.
        special = [unk_token, bos_token, eos_token, pad_token, cls_token, sep_token, mask_token]
        self.tokenizer.add_special_tokens(list(dict.fromkeys(special)))

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            cls_token=cls_token,
            sep_token=sep_token,
            mask_token=mask_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def add_special_tokens(self, special_tokens_dict):
        # The backend expects a flat list of tokens rather than a {name: token} mapping.
        tokens = [t for v in special_tokens_dict.values() for t in (v if isinstance(v, list) else [v])]
        return self.tokenizer.add_special_tokens(tokens)

    def _tokenize(self, text):
        # Tokenize with the byte-level BPE backend; special tokens are handled
        # separately by the parent class.
        return self.tokenizer.encode(text, add_special_tokens=False).tokens

    def _convert_token_to_id(self, token):
        # Fall back to the unknown token for out-of-vocabulary tokens.
        token_id = self.tokenizer.token_to_id(token)
        return token_id if token_id is not None else self.tokenizer.token_to_id(self.unk_token)

    def _convert_id_to_token(self, index):
        # Reverse lookup in the backend vocabulary.
        return self.tokenizer.id_to_token(index) or self.unk_token

    def encode(self, text, **kwargs):
        # Unlike the base `encode`, this returns both the ids and an
        # all-ones attention mask in a single dict.
        input_ids = [self._convert_token_to_id(token) for token in self._tokenize(text)]
        attention_mask = [1] * len(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def decode(self, ids, skip_special_tokens=False, **kwargs):
        # Reassemble the text with the backend's byte-level decoder.
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

    @property
    def vocab_size(self):
        # Required by `PreTrainedTokenizer`; size of the backend vocabulary.
        return self.tokenizer.get_vocab_size()

    def get_vocab(self):
        return self.tokenizer.get_vocab()

    def save_vocabulary(self, save_directory, filename_prefix=None):
        # `save_pretrained` passes a directory and expects the written file paths back.
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.tokenizer.get_vocab(), f, ensure_ascii=False)
        return (vocab_file,)
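
# Minimal usage sketch (illustrative only, not part of the tokenizer itself):
# the BPE backend starts out empty, so a tiny vocabulary is trained in place
# with `tokenizers.trainers.BpeTrainer` before round-tripping a string. The
# sample sentences and output location are arbitrary choices.
if __name__ == "__main__":
    tokenizer = OBITokenizer()

    # Train a throwaway vocabulary so encoding has merges to work with.
    trainer = trainers.BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
    tokenizer.tokenizer.train_from_iterator(["hello world", "tokenizers are fun"], trainer=trainer)

    encoded = tokenizer.encode("hello world")
    print(encoded["input_ids"], encoded["attention_mask"])
    print(tokenizer.decode(encoded["input_ids"]))

    # Persist the learned vocabulary as vocab.json in the current directory.
    tokenizer.save_vocabulary(".")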
|