Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2018 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" Auto Model class.""" | |
import warnings | |
from collections import OrderedDict | |
from transformers.utils import logging | |
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update | |
from .configuration_auto import CONFIG_MAPPING_NAMES | |
# Module-level logger obtained from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
# Per-task name tables. Each one maps a model-type key to the *name* (a string, not the
# class object) of the model class serving that task. Entry order is kept via OrderedDict.
# Each table is later paired with CONFIG_MAPPING_NAMES inside a _LazyAutoMapping.

# Base model mapping
MODEL_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerModel",
        "longformer": "LongformerModel",
    }
)

# Model for pre-training mapping
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    {
        "longformer": "LongformerForMaskedLM",
    }
)

# Model with LM heads mapping
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerForMaskedLM",
        "longformer": "LongformerForMaskedLM",
    }
)

# Model for Causal LM mapping
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerForCausalLM",
    }
)

# Model for Masked LM mapping
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerForMaskedLM",
        "longformer": "LongformerForMaskedLM",
    }
)

# Model for Seq2Seq Causal LM mapping
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    {
        "t5": "T5ForConditionalGeneration",
    }
)

# Model for Seq2Seq speech (speech-to-text) mapping
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    {
        "speech-encoder-decoder": "SpeechEncoderDecoderModel",
        "speech_to_text": "Speech2TextForConditionalGeneration",
    }
)

# Model for Sequence Classification mapping
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerForSequenceClassification",
        "longformer": "LongformerForSequenceClassification",
    }
)

# Model for Question Answering mapping
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerForQuestionAnswering",
        "longformer": "LongformerForQuestionAnswering",
    }
)

# Model for Table Question Answering mapping
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    {
        "tapas": "TapasForQuestionAnswering",
    }
)

# Model for Token Classification mapping
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerForTokenClassification",
        "longformer": "LongformerForTokenClassification",
    }
)

# Model for Multiple Choice mapping
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    {
        "roformer": "RoFormerForMultipleChoice",
        "longformer": "LongformerForMultipleChoice",
    }
)
# Lazy mapping objects: each pairs CONFIG_MAPPING_NAMES with one per-task name table.
# NOTE(review): presumably _LazyAutoMapping resolves the class-name strings only on
# access — confirm against .auto_factory.
MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
# Auto class for base models, bound to MODEL_MAPPING.
class AutoModel(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_MAPPING


# auto_class_update (from .auto_factory) post-processes the class; its return value replaces it.
AutoModel = auto_class_update(AutoModel)
# Auto class for pre-training heads, bound to MODEL_FOR_PRETRAINING_MAPPING.
class AutoModelForPreTraining(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
# Intentionally private: the public AutoModelWithLMHead subclass wraps this one and
# adds the deprecation warnings. Bound to MODEL_WITH_LM_HEAD_MAPPING.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
# Auto class for causal language-modeling heads, bound to MODEL_FOR_CAUSAL_LM_MAPPING.
class AutoModelForCausalLM(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)
# Auto class for masked language-modeling heads, bound to MODEL_FOR_MASKED_LM_MAPPING.
class AutoModelForMaskedLM(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)
# Auto class for sequence-to-sequence LM heads, bound to MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Overrides the example checkpoint with a T5 one, matching the only entry in the mapping.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    checkpoint_for_example="t5-base",
    head_doc="sequence-to-sequence language modeling",
)
# Auto class for sequence-classification heads, bound to MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(AutoModelForSequenceClassification, head_doc="sequence classification")
# Auto class for question-answering heads, bound to MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)
# Auto class for table question-answering heads, bound to
# MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Uses a TAPAS checkpoint as the example, matching the only entry in the mapping.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
    head_doc="table question answering",
)
# Auto class for token-classification heads, bound to MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)
# Auto class for multiple-choice heads, bound to MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)
# Auto class for sequence-to-sequence speech-to-text models, bound to
# MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Mapping table this auto class is bound to.
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# head_doc is the human-readable task label handed to auto_class_update, like every
# sibling auto class above.
AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
# Deprecation notice shared by both AutoModelWithLMHead entry points below.
_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE = (
    "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
    "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
    "`AutoModelForSeq2SeqLM` for encoder-decoder models."
)


# Public, deprecated wrapper around _AutoModelWithLMHead: each constructor emits a
# FutureWarning and then defers to the parent implementation unchanged.
class AutoModelWithLMHead(_AutoModelWithLMHead):
    # Both constructors must be classmethods. They are called on the class itself, so
    # without the decorator the first positional argument would be bound to ``cls``.
    # The zero-argument ``super()`` also needs ``cls`` to reach the parent implementation.

    @classmethod
    def from_config(cls, config):
        """Deprecated: warn, then build the model through the parent ``from_config``.

        Args:
            config: Configuration object, forwarded unchanged to the parent ``from_config``.

        Returns:
            Whatever the parent ``from_config`` returns.
        """
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Deprecated: warn, then load the model through the parent ``from_pretrained``.

        Args:
            pretrained_model_name_or_path: Forwarded unchanged to the parent ``from_pretrained``.
            *model_args: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            Whatever the parent ``from_pretrained`` returns.
        """
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)