# expert_judge_model.py
import inspect

import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class ExpertJudgeCrossEncoder(nn.Module):
"""
The "Expert Judge" Cross-Encoder model.
"""

    def __init__(self, model_name='bert-base-uncased'):
        """
        Initializes the cross-encoder model.
        """
        super().__init__()
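        # num_labels=1 gives a single-logit head, so the model produces one
        # relevance score per (query, chunk) pair instead of class probabilities.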
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=1
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass for the cross-encoder. Arguments are assembled dynamically
        so the class supports architectures that expect token_type_ids (such as
        BERT) as well as those that do not (such as Qwen2).

        Args:
            input_ids (Tensor): Token IDs for the concatenated query/chunk sequence.
            attention_mask (Tensor): Attention mask for the input sequence.
            token_type_ids (Tensor, optional): Segment IDs distinguishing the query
                from the chunk. Defaults to None.
        """
        # 1. Create a dictionary with the arguments that are always required.
        model_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
        # 2. Only pass token_type_ids if the underlying model's forward method
        #    actually accepts it. This keeps the class compatible with models
        #    that use segment IDs (like BERT) and those that don't (like Qwen2).
        forward_params = inspect.signature(self.model.forward).parameters
        if 'token_type_ids' in forward_params and token_type_ids is not None:
            model_inputs['token_type_ids'] = token_type_ids
        # 3. Pass the dynamically built arguments to the model via dictionary unpacking.
        outputs = self.model(**model_inputs)
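        # outputs.logits has shape (batch_size, 1): one relevance score per pair.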
        return outputs.logits


def get_tokenizer(model_name='bert-base-uncased'):
    """
    Helper that returns the tokenizer matching the given checkpoint.
    AutoTokenizer resolves the correct tokenizer class for the model automatically.
    """
    return AutoTokenizer.from_pretrained(model_name)
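

# --- Usage sketch (illustrative only) ---
# Minimal example of scoring one (query, chunk) pair with the cross-encoder.
# The query/chunk strings and the 512-token limit are assumptions for this demo;
# the real calling code lives elsewhere in the application.
if __name__ == "__main__":
    import torch

    model_name = 'bert-base-uncased'
    judge = ExpertJudgeCrossEncoder(model_name=model_name)
    judge.eval()
    tokenizer = get_tokenizer(model_name)

    # Encode query and chunk as a single pair; for BERT-style models the
    # tokenizer also produces token_type_ids separating the two segments.
    encoded = tokenizer(
        "What is the refund policy?",               # hypothetical query
        "Refunds are available within 30 days.",    # hypothetical chunk
        truncation=True,
        max_length=512,
        return_tensors='pt'
    )

    with torch.no_grad():
        score = judge(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            token_type_ids=encoded.get('token_type_ids')
        )
    print(score.shape, score.item())  # torch.Size([1, 1]) and the raw logit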