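# NOTE(review): the next six lines look like hosting-page residue captured
# with the file; kept as comments so the module still parses.
# HaloMaster's picture
# add fengshen
# 50f0fbb
# raw
# history blame
# 8.52 kB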
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class."""
import warnings
from collections import OrderedDict
from transformers.utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
# Base (headless) models: model-type key -> model class name, in insertion order.
MODEL_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerModel"),
        ("longformer", "LongformerModel"),
    )
)
# Pre-training models: model-type key -> model class name.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict()
MODEL_FOR_PRETRAINING_MAPPING_NAMES["longformer"] = "LongformerForMaskedLM"
# Models with a language-modeling head: model-type key -> model class name.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
    )
)
# Causal language models: model-type key -> model class name.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict()
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["roformer"] = "RoFormerForCausalLM"
# Masked language models: model-type key -> model class name.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
    )
)
# Sequence-to-sequence language models: model-type key -> model class name.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict()
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES["t5"] = "T5ForConditionalGeneration"
# Speech sequence-to-sequence models: model-type key -> model class name.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
    )
)
# Sequence-classification models: model-type key -> model class name.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
    )
)
# Question-answering models: model-type key -> model class name.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
    )
)
# Table question-answering models: model-type key -> model class name.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict()
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES["tapas"] = "TapasForQuestionAnswering"
# Token-classification models: model-type key -> model class name.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
    )
)
# Multiple-choice models: model-type key -> model class name.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
    )
)
# One _LazyAutoMapping per task family. Each pairs CONFIG_MAPPING_NAMES with
# the corresponding *_MAPPING_NAMES table of model class names.
MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
class AutoModel(_BaseAutoModelClass):
    # Auto class for base models; candidates come from MODEL_MAPPING.
    _model_mapping = MODEL_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModel = auto_class_update(
    AutoModel,
)
class AutoModelForPreTraining(_BaseAutoModelClass):
    # Auto class for pre-training models; candidates come from
    # MODEL_FOR_PRETRAINING_MAPPING.
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
# Deliberately private: the public `AutoModelWithLMHead` subclass is the one
# that adds the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    # Candidates come from MODEL_WITH_LM_HEAD_MAPPING.
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
class AutoModelForCausalLM(_BaseAutoModelClass):
    # Auto class for causal language models; candidates come from
    # MODEL_FOR_CAUSAL_LM_MAPPING.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)
class AutoModelForMaskedLM(_BaseAutoModelClass):
    # Auto class for masked language models; candidates come from
    # MODEL_FOR_MASKED_LM_MAPPING.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Auto class for sequence-to-sequence language models; candidates come
    # from MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    # Auto class for sequence classification; candidates come from
    # MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Auto class for question answering; candidates come from
    # MODEL_FOR_QUESTION_ANSWERING_MAPPING.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Auto class for table question answering; candidates come from
    # MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
class AutoModelForTokenClassification(_BaseAutoModelClass):
    # Auto class for token classification; candidates come from
    # MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    # Auto class for multiple choice; candidates come from
    # MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Auto class for speech sequence-to-sequence models; candidates come from
    # MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
# The head_doc text is spelled "modeling" (it is handed to auto_class_update,
# presumably for generated documentation — confirm in auto_factory).
AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
class AutoModelWithLMHead(_AutoModelWithLMHead):
    """Deprecated public variant of `_AutoModelWithLMHead`.

    Both constructors emit a `FutureWarning` naming the task-specific
    replacement classes and then defer to the parent implementation.
    """

    @classmethod
    def from_config(cls, config):
        """Warn about the deprecation, then delegate to the parent `from_config`."""
        deprecation_notice = (
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models."
        )
        warnings.warn(deprecation_notice, FutureWarning)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Warn about the deprecation, then delegate to the parent `from_pretrained`.

        All positional and keyword arguments are forwarded unchanged.
        """
        deprecation_notice = (
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models."
        )
        warnings.warn(deprecation_notice, FutureWarning)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)