File size: 672 Bytes
4ffd891 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from transformers import PretrainedConfig
class MultiHeadConfig(PretrainedConfig):
    """Configuration for a classifier model built on a pretrained encoder.

    Defaults describe a binary "irrelevant"/"relevant" classifier on top of
    ``microsoft/deberta-v3-small`` with a ``DebertaV2TokenizerFast`` tokenizer
    and a classifier dropout of 0.1.
    """

    model_type = "multihead"

    def __init__(
        self,
        encoder_name="microsoft/deberta-v3-small",
        **kwargs,
    ):
        """Create the config.

        Args:
            encoder_name: Hub id or local path of the encoder backbone.
            **kwargs: Forwarded to ``PretrainedConfig``. Defaults supplied here
                when absent: ``classifier_dropout=0.1``, ``num_labels=2``,
                ``id2label={0: "irrelevant", 1: "relevant"}``,
                ``label2id={"irrelevant": 0, "relevant": 1}`` (the label maps
                only when ``num_labels`` is 2) and
                ``tokenizer_class="DebertaV2TokenizerFast"``.
        """
        self.encoder_name = encoder_name
        # Not a PretrainedConfig field; consume it here instead of forwarding.
        self.classifier_dropout = kwargs.pop("classifier_dropout", 0.1)

        # PretrainedConfig.__init__ reads id2label / label2id / num_labels /
        # tokenizer_class from kwargs and falls back to its own defaults, so
        # values assigned to self before super().__init__ would be overwritten.
        # Supply our defaults through kwargs instead.
        if kwargs.get("num_labels", 2) == 2:
            # Only default the binary label names when they fit; a caller
            # asking for a different num_labels keeps the base-class naming.
            kwargs.setdefault("id2label", {0: "irrelevant", 1: "relevant"})
            kwargs.setdefault("label2id", {"irrelevant": 0, "relevant": 1})
        kwargs.setdefault("tokenizer_class", "DebertaV2TokenizerFast")
        super().__init__(**kwargs)
|