# expert_judge_model.py
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class ExpertJudgeCrossEncoder(nn.Module):
    """
    The "Expert Judge" Cross-Encoder model.
    """
    def __init__(self, model_name='bert-base-uncased'):
        """
        Initializes the Cross-Encoder model.
        """
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=1
        )

    # --- CORRECTED: The forward method is now model-aware ---
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass for the cross-encoder. This version dynamically handles
        arguments to support different model architectures.

        Args:
            input_ids (Tensor): Token IDs for the concatenated sequence.
            attention_mask (Tensor): Attention mask for the input sequence.
            token_type_ids (Tensor, optional): Segment IDs to distinguish
                query from chunk. Defaults to None.
        """
        # 1. Create a dictionary with the arguments that are always required.
        model_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

        # 2. Check if the underlying model's forward method accepts
        #    'token_type_ids'. This makes the class compatible with models
        #    that use it (like BERT) and those that don't (like Qwen2).
        if 'token_type_ids' in self.model.forward.__code__.co_varnames:
            if token_type_ids is not None:
                model_inputs['token_type_ids'] = token_type_ids

        # 3. Pass the dynamically built arguments to the model using
        #    dictionary unpacking.
        outputs = self.model(**model_inputs)
        return outputs.logits


def get_tokenizer(model_name='bert-base-uncased'):
    """
    Helper function to get the tokenizer corresponding to the model.
    Using AutoTokenizer is generally safer.
    """
    return AutoTokenizer.from_pretrained(model_name)
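

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of scoring one (query, chunk) pair with the cross-encoder.
# The query/chunk strings are hypothetical placeholders; the tokenizer pairs
# them as "[CLS] query [SEP] chunk [SEP]" and, for BERT-style models, also
# returns token_type_ids marking the two segments.
if __name__ == '__main__':
    import torch

    model_name = 'bert-base-uncased'
    tokenizer = get_tokenizer(model_name)
    model = ExpertJudgeCrossEncoder(model_name)
    model.eval()

    encoded = tokenizer(
        'What is the capital of France?',          # query (placeholder)
        'Paris is the capital and largest city.',  # chunk (placeholder)
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    with torch.no_grad():
        score = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            # BatchEncoding is dict-like, so .get() returns None for models
            # whose tokenizers do not produce token_type_ids (e.g. Qwen2).
            token_type_ids=encoded.get('token_type_ids')
        )

    # With num_labels=1, the output is a single relevance logit per pair;
    # the value is meaningless until the model is fine-tuned.
    print(score.shape)  # torch.Size([1, 1])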