# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class."""

import warnings
from collections import OrderedDict

from transformers.utils import logging

from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)

MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("roformer", "RoFormerModel"),
        ("longformer", "LongformerModel"),
    ]
)

MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("longformer", "LongformerForMaskedLM"),
    ]
)

MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("roformer", "RoFormerForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
    ]
)

MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("roformer", "RoFormerForCausalLM"),
    ]
)

MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("roformer", "RoFormerForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
    ]
)

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("t5", "T5ForConditionalGeneration"),
    ]
)

MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
    ]
)

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("roformer", "RoFormerForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
    ]
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("roformer", "RoFormerForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
    ]
)

MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("roformer", "RoFormerForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
    ]
)

MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("roformer", "RoFormerForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
    ]
)

MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
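# The mappings above resolve lazily: the model class is only imported on first lookup. A minimal
# sketch of how a lookup behaves (not executed on import; hedged: it assumes `_LazyAutoMapping`
# supports config-class keys, as the auto factory in `transformers` does):
#
#     from transformers import T5Config
#     # Returns the T5ForConditionalGeneration class, importing the t5 module on demand.
#     model_cls = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[T5Config]
#     assert model_cls.__name__ == "T5ForConditionalGeneration"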


class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


AutoModel = auto_class_update(AutoModel)
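
# Usage sketch for the base auto class (hedged: the checkpoint name "allenai/longformer-base-4096"
# is an assumption; any checkpoint whose model type is registered in MODEL_MAPPING_NAMES above,
# e.g. "longformer" or "roformer", would work the same way):
#
#     from transformers import AutoConfig, AutoModel
#     model = AutoModel.from_pretrained("allenai/longformer-base-4096")   # config + pretrained weights
#     config = AutoConfig.from_pretrained("allenai/longformer-base-4096")
#     model = AutoModel.from_config(config)                               # architecture only, random weights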


class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")


class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
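
# Usage sketch for the seq2seq head, reusing the "t5-base" checkpoint referenced in
# `checkpoint_for_example` above (hedged: assumes a matching tokenizer is available on the Hub):
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("t5-base")
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
#     inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt")
#     outputs = model.generate(**inputs)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))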


class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)


class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
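
# Usage sketch for extractive question answering (hedged: the checkpoint name is an assumption;
# any checkpoint whose model type maps to a *ForQuestionAnswering class above will work, and a
# base checkpoint without a QA head will have that head freshly initialized):
#
#     import torch
#     from transformers import AutoModelForQuestionAnswering, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
#     model = AutoModelForQuestionAnswering.from_pretrained("allenai/longformer-base-4096")
#     inputs = tokenizer("Who wrote it?", "The report was written by Ada.", return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     start = outputs.start_logits.argmax()
#     end = outputs.end_logits.argmax()
#     answer = tokenizer.decode(inputs["input_ids"][0][start : end + 1])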


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
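
# Usage sketch for table question answering, reusing the "google/tapas-base-finetuned-wtq"
# checkpoint referenced in `checkpoint_for_example` above (hedged: assumes `pandas` is installed
# and that the TAPAS tokenizer accepts a DataFrame of strings, as in the transformers TAPAS docs):
#
#     import pandas as pd
#     from transformers import AutoModelForTableQuestionAnswering, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
#     model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
#     table = pd.DataFrame({"City": ["Paris", "Berlin"], "Population": ["2.1M", "3.6M"]})
#     inputs = tokenizer(table=table, queries=["Which city has 3.6M people?"], return_tensors="pt")
#     outputs = model(**inputs)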


class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class AutoModelWithLMHead(_AutoModelWithLMHead):
    @classmethod
    def from_config(cls, config):
        warnings.warn(
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warnings.warn(
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
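
# Migration sketch for the deprecated class (hedged: the checkpoint names are assumptions tied to
# the mappings registered above; the replacement classes are the ones named in the warning):
#
#     from transformers import AutoModelForMaskedLM, AutoModelForSeq2SeqLM
#     # Instead of AutoModelWithLMHead.from_pretrained("allenai/longformer-base-4096"):
#     model = AutoModelForMaskedLM.from_pretrained("allenai/longformer-base-4096")
#     # Instead of AutoModelWithLMHead.from_pretrained("t5-base"):
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")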