feat: Add EuroBertForQuestionAnswering
This is copied directly from the [210m modeling script](https://huggingface.co/EuroBERT/EuroBERT-210m/edit/main/modeling_eurobert.py).
- modeling_eurobert.py +98 -1

```diff
@@ -30,7 +30,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
@@ -951,10 +951,107 @@ class EuroBertForTokenClassification(EuroBertPreTrainedModel):
         )
 
 
+@add_start_docstrings(
+    """
+    The EuroBert Model with a span classification head on top for extractive question-answering tasks
+    like SQuAD (a linear layers on top of the hidden-states output to compute span start logits
+    and span end logits).
+    """,
+    EUROBERT_START_DOCSTRING,
+)
+class EuroBertForQuestionAnswering(EuroBertPreTrainedModel):
+    def __init__(self, config: EuroBertConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = EuroBertModel(config)
+
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
 __all__ = [
     "EuroBertPreTrainedModel",
     "EuroBertModel",
     "EuroBertForMaskedLM",
     "EuroBertForSequenceClassification",
     "EuroBertForTokenClassification",
+    "EuroBertForQuestionAnswering",
 ]
```
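For reviewers who want to exercise the new head, here is a minimal inference sketch. It assumes the repo's `config.json` `auto_map` is also updated so that `AutoModelForQuestionAnswering` resolves to `EuroBertForQuestionAnswering` (that change is not part of this diff), and the question/context strings are purely illustrative.

```python
# Minimal sketch, not part of the commit: load the remote-code model and decode a span.
# Assumes the Hub repo's auto_map exposes EuroBertForQuestionAnswering (hence trust_remote_code=True).
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

repo = "EuroBERT/EuroBERT-210m"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForQuestionAnswering.from_pretrained(repo, trust_remote_code=True)

question = "Where is the Eiffel Tower located?"  # illustrative inputs
context = "The Eiffel Tower is a wrought-iron tower on the Champ de Mars in Paris, France."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # QuestionAnsweringModelOutput with start_logits / end_logits

# Greedy span decoding: most likely start and end token indices.
start_idx = outputs.start_logits.argmax(dim=-1).item()
end_idx = outputs.end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start_idx : end_idx + 1])
print(answer)
```

Since the checkpoint ships with a freshly initialized `qa_outputs` layer, the predicted span is arbitrary until the model is fine-tuned on an extractive QA dataset such as SQuAD.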
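The training path is the standard BERT-style span loss: when `start_positions` and `end_positions` are passed, the forward method clamps out-of-range labels, skips them via `ignore_index`, and averages two cross-entropy terms. The standalone sketch below reproduces that computation on dummy tensors; shapes and label values are illustrative, not real model outputs.

```python
# Standalone sketch of the span loss computed inside EuroBertForQuestionAnswering.forward.
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len = 2, 16
start_logits = torch.randn(batch_size, seq_len)  # start_logits after .squeeze(-1)
end_logits = torch.randn(batch_size, seq_len)    # end_logits after .squeeze(-1)
start_positions = torch.tensor([3, 20])          # second label is out of range on purpose
end_positions = torch.tensor([5, 25])

# Out-of-range positions are clamped to seq_len and then skipped by ignore_index,
# mirroring the clamping in the new forward method.
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss.item())
```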