from tokenizers import Tokenizer, models, pre_tokenizers, trainers, decoders
import json

class OBITokenizer:
    def __init__(self):
        # Initialize a BPE model for tokenization
        bpe_model = models.BPE()
        # Initialize the tokenizer
        self.tokenizer = Tokenizer(bpe_model)
        # Use byte-level pre-tokenization and decoding
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
        self.tokenizer.decoder = decoders.ByteLevel()

    def train(self, files, save_path):
        # Training: fit the BPE model on the given text files
        trainer = trainers.BpeTrainer(special_tokens=["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"])
        self.tokenizer.train(files, trainer=trainer)
        # Save the trained tokenizer to a file
        self.tokenizer.save(save_path)

    def save_config(self, config_file):
        # Serialize the tokenizer's config to a JSON file
        config_dict = {
            "tokenizer_type": "custom",
            "vocab_size": self.tokenizer.get_vocab_size(),
            "tokenizer_class": "OBITokenizer",
            "auto_map": {
                "AutoTokenizer": ["tokenizeConfig.OBITokenizer"]
            },
            "bos_token": "[CLS]",
            "eos_token": "[SEP]",
            "unk_token": "[UNK]",
            "pad_token": "[PAD]",
            "mask_token": "[MASK]"
            # Add other custom settings if needed
        }
        with open(config_file, "w") as f:
            json.dump(config_dict, f)

    def encode(self, text):
        # Encode text using the custom tokenizer
        encoding = self.tokenizer.encode(text)
        return encoding.ids

    def decode(self, ids):
        # Decode IDs to text using the custom tokenizer
        return self.tokenizer.decode(ids)
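

if __name__ == "__main__":
    # Usage sketch (not part of the original module): trains the tokenizer on a
    # hypothetical corpus file and round-trips a sample string. The file names
    # "corpus.txt", "tokenizer.json", and "tokenizer_config.json" are placeholder
    # assumptions, not paths defined elsewhere in this repository.
    obi_tokenizer = OBITokenizer()
    obi_tokenizer.train(files=["corpus.txt"], save_path="tokenizer.json")
    obi_tokenizer.save_config("tokenizer_config.json")

    ids = obi_tokenizer.encode("Hello, world!")
    print("Token IDs:", ids)
    print("Decoded:", obi_tokenizer.decode(ids))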