# coding=utf-8 # Copyright 2018 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Auto Model class.""" import warnings from collections import OrderedDict from transformers.utils import logging from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update from .configuration_auto import CONFIG_MAPPING_NAMES logger = logging.get_logger(__name__) MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping ("roformer", "RoFormerModel"), ("longformer", "LongformerModel"), ] ) MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping ("longformer", "LongformerForMaskedLM"), ] ) MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping ("roformer", "RoFormerForMaskedLM"), ("longformer", "LongformerForMaskedLM"), ] ) MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping ("roformer", "RoFormerForCausalLM"), ] ) MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping ("roformer", "RoFormerForMaskedLM"), ("longformer", "LongformerForMaskedLM"), ] ) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping ("t5", "T5ForConditionalGeneration"), ] ) MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( [ ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ] ) MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping ("roformer", "RoFormerForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ] ) MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping ("roformer", "RoFormerForQuestionAnswering"), ("longformer", "LongformerForQuestionAnswering"), ] ) MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Table Question Answering mapping ("tapas", "TapasForQuestionAnswering"), ] ) MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping ("roformer", "RoFormerForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ] ) MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping ("roformer", "RoFormerForMultipleChoice"), ("longformer", "LongformerForMultipleChoice"), ] ) MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES ) MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES ) MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES ) MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES ) MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES ) MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES) MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) class AutoModel(_BaseAutoModelClass): _model_mapping = MODEL_MAPPING AutoModel = auto_class_update(AutoModel) class AutoModelForPreTraining(_BaseAutoModelClass): _model_mapping = MODEL_FOR_PRETRAINING_MAPPING AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") # Private on purpose, the public class will add the deprecation warnings. class _AutoModelWithLMHead(_BaseAutoModelClass): _model_mapping = MODEL_WITH_LM_HEAD_MAPPING _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") class AutoModelForCausalLM(_BaseAutoModelClass): _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") class AutoModelForMaskedLM(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MASKED_LM_MAPPING AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") class AutoModelForSeq2SeqLM(_BaseAutoModelClass): _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING AutoModelForSeq2SeqLM = auto_class_update( AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base" ) class AutoModelForSequenceClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING AutoModelForSequenceClassification = auto_class_update( AutoModelForSequenceClassification, head_doc="sequence classification" ) class AutoModelForQuestionAnswering(_BaseAutoModelClass): _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING AutoModelForTableQuestionAnswering = auto_class_update( AutoModelForTableQuestionAnswering, head_doc="table question answering", checkpoint_for_example="google/tapas-base-finetuned-wtq", ) class AutoModelForTokenClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") class AutoModelForMultipleChoice(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING AutoModelForSpeechSeq2Seq = auto_class_update( AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing" ) class AutoModelWithLMHead(_AutoModelWithLMHead): @classmethod def from_config(cls, config): warnings.warn( "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " "`AutoModelForSeq2SeqLM` for encoder-decoder models.", FutureWarning, ) return super().from_config(config) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): warnings.warn( "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " "`AutoModelForSeq2SeqLM` for encoder-decoder models.", FutureWarning, ) return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)