# expert_judge_model.py

import inspect

import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class ExpertJudgeCrossEncoder(nn.Module):
    """
    The "Expert Judge" Cross-Encoder model.
    """
    def __init__(self, model_name='bert-base-uncased'):
        """
        Initializes the Cross-Encoder model.
        """
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=1
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass for the cross-encoder. Arguments are assembled dynamically
        so the class works across different model architectures.

        Args:
            input_ids (Tensor): Token IDs for the concatenated query/chunk sequence.
            attention_mask (Tensor): Attention mask for the input sequence.
            token_type_ids (Tensor, optional): Segment IDs distinguishing the query
                from the chunk. Defaults to None.

        Returns:
            Tensor: One relevance logit per input pair, shape (batch_size, 1).
        """
        # 1. Collect the arguments every architecture requires.
        model_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

        # 2. Only pass token_type_ids to models whose forward signature accepts
        #    it (like BERT); models such as Qwen2 do not use segment IDs.
        #    inspect.signature is more reliable than __code__.co_varnames, which
        #    also lists local variables and can yield false positives.
        if token_type_ids is not None:
            if 'token_type_ids' in inspect.signature(self.model.forward).parameters:
                model_inputs['token_type_ids'] = token_type_ids

        # 3. Unpack the dynamically built arguments into the model call.
        outputs = self.model(**model_inputs)

        return outputs.logits

def get_tokenizer(model_name='bert-base-uncased'):
    """
    Returns the tokenizer matching the given model. AutoTokenizer resolves the
    correct tokenizer class for any checkpoint, keeping this helper model-agnostic.
    """
    return AutoTokenizer.from_pretrained(model_name)
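
# --- Usage sketch (illustrative) ---
# A minimal example of scoring one (query, chunk) pair. The example texts and
# max_length value are assumptions for demonstration, not values from this file;
# the checkpoint is the 'bert-base-uncased' default defined above.
if __name__ == '__main__':
    import torch

    model = ExpertJudgeCrossEncoder()
    tokenizer = get_tokenizer()

    query = "What is a cross-encoder?"
    chunk = "A cross-encoder scores a text pair by encoding both texts jointly."

    # Passing two texts makes the tokenizer build one concatenated sequence
    # ([CLS] query [SEP] chunk [SEP]) and, for BERT, emit token_type_ids.
    encoded = tokenizer(query, chunk, truncation=True,
                        max_length=256, return_tensors='pt')

    model.eval()
    with torch.no_grad():
        logit = model(**encoded)  # shape (1, 1): one relevance logit per pair

    print(f"Relevance logit: {logit.item():.4f}")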