File size: 3,797 Bytes
2d0282e
88541c2
 
ac72af3
 
9d6689e
88541c2
b5f3842
9d6689e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88541c2
 
 
9d6689e
88541c2
 
 
 
4e47b59
 
9d6689e
4e47b59
 
 
 
 
 
 
 
9d6689e
 
 
ac72af3
 
9d6689e
 
 
 
 
 
ac72af3
 
 
 
 
9d6689e
88541c2
 
 
 
 
 
 
 
 
 
 
 
 
9d6689e
88541c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d6689e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import json
import os
from typing import Dict, List, Optional, Union

from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
from transformers.utils import PaddingStrategy

class OBITokenizer(PreTrainedTokenizer):
    """Hugging Face slow-tokenizer wrapper around a byte-level BPE ``tokenizers.Tokenizer``.

    The backend tokenizer is stored in ``self.tokenizer``. It is loaded from
    ``vocab_file`` (a JSON file written by ``tokenizers.Tokenizer.save``, e.g.
    by :meth:`train`). If that file does not exist yet, an untrained
    byte-level BPE tokenizer is created instead so it can be fitted with
    :meth:`train`.

    Special tokens are fixed to the bracketed set used by :meth:`train`
    (``[PAD]``, ``[CLS]``, ``[SEP]``, ``[MASK]``, ``[UNK]``). ``[CLS]`` doubles
    as BOS and ``[SEP]`` as EOS.
    """

    # __init__ keyword of the vocabulary argument -> filename written by
    # save_vocabulary and looked up by from_pretrained.
    vocab_files_names = {"vocab_file": "tokenizer.json"}

    # Special tokens registered with the HF base class. They take precedence
    # over the unk/bos/eos/pad constructor arguments, because the vocabulary
    # produced by train() only contains these bracketed tokens.
    _OBI_SPECIAL_TOKENS = {
        "pad_token": "[PAD]",
        "cls_token": "[CLS]",
        "sep_token": "[SEP]",
        "mask_token": "[MASK]",
        "unk_token": "[UNK]",
        "bos_token": "[CLS]",
        "eos_token": "[SEP]",
    }

    # Distinct special tokens, in the order the BPE trainer reserves them.
    # __init__ registers them on the backend in the same order, so training
    # and construction agree on the special-token set.
    _OBI_TRAINER_TOKENS = ("[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]")

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token=None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        auto_map=None,
        tokenizer_class="OBITokenizer",
        **kwargs,
    ):
        """Build or load the backend tokenizer, then initialise the HF base class.

        Args:
            vocab_file: Path to a serialized ``tokenizers.Tokenizer`` JSON file.
                If it is ``None`` or does not exist, an untrained byte-level
                BPE tokenizer is created instead (fit it with :meth:`train`).
            unk_token, bos_token, eos_token, pad_token: Accepted only for
                signature/config compatibility. The fixed tokens in
                ``_OBI_SPECIAL_TOKENS`` are used instead.
            add_bos_token, add_eos_token, clean_up_tokenization_spaces:
                Forwarded to ``PreTrainedTokenizer.__init__``.
            auto_map, tokenizer_class: Absorb the matching keys of a saved
                tokenizer config. They are not used.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # The backend must exist before super().__init__(). Recent transformers
        # releases call get_vocab()/vocab_size while registering special tokens
        # inside the base constructor.
        if vocab_file is not None and os.path.isfile(vocab_file):
            self.tokenizer = Tokenizer.from_file(vocab_file)
        else:
            self.tokenizer = Tokenizer(models.BPE())
            self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
            self.tokenizer.decoder = decoders.ByteLevel()
        # Register the special tokens on the backend itself (the original code
        # tried to do this on the dict returned by get_vocab(), which fails).
        # This lets the HF layer resolve them to backend ids.
        self.tokenizer.add_special_tokens(list(self._OBI_TRAINER_TOKENS))
        self.vocab_file = vocab_file

        # Replace caller/config special tokens with the fixed set. Updating
        # kwargs, rather than passing explicit keywords as well, avoids
        # "multiple values for keyword" errors when a saved config already
        # contains e.g. mask_token.
        kwargs.update(self._OBI_SPECIAL_TOKENS)
        super().__init__(
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Size of the backend model vocabulary, excluding added tokens."""
        return self.tokenizer.get_vocab_size(with_added_tokens=False)

    def get_vocab(self) -> Dict[str, int]:
        """Return the backend's token -> id mapping, including added tokens."""
        return self.tokenizer.get_vocab(with_added_tokens=True)

    def _tokenize(self, text, **kwargs) -> List[str]:
        """Split *text* into backend token strings (used by the HF ``__call__`` path)."""
        return self.tokenizer.encode(text, add_special_tokens=False).tokens

    def _convert_token_to_id(self, token) -> int:
        """Map a token string to its id, falling back to the ``[UNK]`` id."""
        token_id = self.tokenizer.token_to_id(token)
        if token_id is None:
            # [UNK] is always registered on the backend in __init__.
            token_id = self.tokenizer.token_to_id(self._OBI_SPECIAL_TOKENS["unk_token"])
        return token_id

    def _convert_id_to_token(self, index) -> str:
        """Map an id to its token string, falling back to ``[UNK]``."""
        token = self.tokenizer.id_to_token(index)
        return token if token is not None else self._OBI_SPECIAL_TOKENS["unk_token"]

    def convert_tokens_to_string(self, tokens) -> str:
        """Join token strings into text using the backend decoder, if one is set."""
        decoder = self.tokenizer.decoder
        if decoder is None:
            return "".join(tokens)
        return decoder.decode(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Serialize the backend tokenizer into *save_directory*.

        Called by ``save_pretrained``. The file name comes from
        ``vocab_files_names`` and is optionally prefixed with
        ``"<filename_prefix>-"``.

        Returns:
            A one-element tuple holding the path of the written JSON file.
        """
        filename = self.vocab_files_names["vocab_file"]
        if filename_prefix:
            filename = f"{filename_prefix}-{filename}"
        path = os.path.join(save_directory, filename)
        self.tokenizer.save(path)
        return (path,)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        **kwargs,
    ) -> dict:
        """Pad one encoded example by delegating to the base implementation.

        Every argument is forwarded by keyword. Extra keywords are passed
        through unchanged, e.g. ``padding_side``, which newer transformers
        releases place before ``return_attention_mask`` in the base signature.
        This keeps the call correct across base-class versions.
        """
        return super()._pad(
            encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            **kwargs,
        )

    def train(self, files, save_path):
        """Fit the backend BPE model on *files* and save it to *save_path*.

        Args:
            files: Paths of text files to train on (passed to
                ``tokenizers.Tokenizer.train``).
            save_path: Destination of the serialized tokenizer JSON. The file
                can later be passed back as ``vocab_file``.
        """
        trainer = trainers.BpeTrainer(special_tokens=list(self._OBI_TRAINER_TOKENS))
        self.tokenizer.train(trainer=trainer, files=files)
        self.tokenizer.save(save_path)

    def save_config(self, config_file):
        """Write a minimal JSON tokenizer config describing this tokenizer to *config_file*."""
        special = self._OBI_SPECIAL_TOKENS
        config_dict = {
            "tokenizer_type": "custom",
            "vocab_size": self.tokenizer.get_vocab_size(),
            "tokenizer_class": "OBITokenizer",
            # The second entry must be JSON null, not the string "null".
            # AutoTokenizer treats a non-null second entry as the name of a
            # fast tokenizer class to import.
            "auto_map": {
                "AutoTokenizer": ["tokenizeConfig.OBITokenizer", None]
            },
            "bos_token": special["bos_token"],
            "eos_token": special["eos_token"],
            "unk_token": special["unk_token"],
            "pad_token": special["pad_token"],
            "mask_token": special["mask_token"],
        }
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config_dict, f)

    def encode(self, text, **kwargs) -> List[int]:
        """Encode *text* with the backend tokenizer and return its token ids.

        Extra keyword arguments (e.g. ``add_special_tokens``, which transformers
        helpers may pass) are accepted for compatibility with
        ``PreTrainedTokenizerBase.encode`` but are ignored.
        """
        return self.tokenizer.encode(text).ids

    def decode(self, ids, skip_special_tokens=True, **kwargs) -> str:
        """Decode token ids back to text with the backend tokenizer.

        Args:
            ids: A single id, or an iterable of ids. Elements may be ints or
                anything ``int()`` accepts (e.g. 0-d tensors).
            skip_special_tokens: Drop special tokens from the output. The
                default ``True`` matches the backend ``Tokenizer.decode``
                default.
            **kwargs: Accepted for compatibility with ``batch_decode`` and
                ``PreTrainedTokenizerBase.decode``, which pass extra keywords
                such as ``clean_up_tokenization_spaces``; ignored.
        """
        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode([int(i) for i in ids], skip_special_tokens=skip_special_tokens)