File size: 672 Bytes
4ffd891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import PretrainedConfig

class MultiHeadConfig(PretrainedConfig):
    """Configuration for the ``multihead`` model type.

    Args:
        encoder_name: Identifier of the encoder checkpoint to use.
            Defaults to ``"microsoft/deberta-v3-small"``.
        **kwargs: Further options. Recognised here:

            * ``classifier_dropout`` (default ``0.1``)
            * ``tokenizer_class`` (default ``"DebertaV2TokenizerFast"``)
            * ``id2label`` / ``label2id`` (default to the binary
              ``irrelevant`` / ``relevant`` mapping when neither
              ``id2label`` nor a ``num_labels`` other than 2 is given)
            * ``num_labels``

            Everything else is forwarded unchanged to
            ``PretrainedConfig.__init__``.
    """

    model_type = "multihead"

    def __init__(
        self,
        encoder_name="microsoft/deberta-v3-small",
        **kwargs
    ):
        self.encoder_name = encoder_name
        self.classifier_dropout = kwargs.pop("classifier_dropout", 0.1)

        # NOTE: the base initializer assigns ``tokenizer_class``, ``id2label``
        # and ``label2id`` itself from ``kwargs`` (falling back to its own
        # generic defaults). Defaults therefore have to be injected into
        # ``kwargs``; setting them on ``self`` beforehand would be overwritten
        # by ``super().__init__``.
        if kwargs.get("tokenizer_class") is None:
            kwargs["tokenizer_class"] = "DebertaV2TokenizerFast"

        id2label = kwargs.get("id2label")
        if id2label is None and kwargs.get("num_labels") in (None, 2):
            # Only apply the binary default when the caller did not ask for a
            # different number of labels.
            id2label = {0: "irrelevant", 1: "relevant"}
            kwargs["id2label"] = id2label

        if isinstance(id2label, dict) and kwargs.get("label2id") is None:
            # Keep the reverse mapping consistent with ``id2label``. Keys may
            # arrive as strings when the config was loaded from JSON.
            kwargs["label2id"] = {
                label: int(idx) for idx, label in id2label.items()
            }

        super().__init__(**kwargs)