Spaces:
Sleeping
Sleeping
Commit
·
101cdda
1
Parent(s):
ffe6d33
fix: new embeddings and rerankh4
Browse files- app/models/expert_judge_model.py +20 -16
app/models/expert_judge_model.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
# expert_judge_model.py
|
| 2 |
|
| 3 |
import torch.nn as nn
|
| 4 |
-
|
| 5 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertTokenizer
|
| 6 |
|
| 7 |
class ExpertJudgeCrossEncoder(nn.Module):
|
| 8 |
"""
|
|
@@ -18,30 +17,35 @@ class ExpertJudgeCrossEncoder(nn.Module):
|
|
| 18 |
num_labels=1
|
| 19 |
)
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# --- MODIFIED: The forward method now accepts token_type_ids ---
|
| 26 |
def forward(self, input_ids, attention_mask, token_type_ids=None):
|
| 27 |
"""
|
| 28 |
-
Forward pass for the cross-encoder.
|
|
|
|
| 29 |
|
| 30 |
Args:
|
| 31 |
input_ids (Tensor): Token IDs for the concatenated sequence.
|
| 32 |
attention_mask (Tensor): Attention mask for the input sequence.
|
| 33 |
token_type_ids (Tensor, optional): Segment IDs to distinguish query from chunk. Defaults to None.
|
| 34 |
"""
|
| 35 |
-
#
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return outputs.logits
|
| 43 |
|
| 44 |
-
# --- MODIFIED: A more robust tokenizer helper function ---
|
| 45 |
def get_tokenizer(model_name='bert-base-uncased'):
|
| 46 |
"""
|
| 47 |
Helper function to get the tokenizer corresponding to the model.
|
|
|
|
| 1 |
# expert_judge_model.py
|
| 2 |
|
| 3 |
import torch.nn as nn
|
| 4 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
| 5 |
|
| 6 |
class ExpertJudgeCrossEncoder(nn.Module):
|
| 7 |
"""
|
|
|
|
| 17 |
num_labels=1
|
| 18 |
)
|
| 19 |
|
| 20 |
+
# --- CORRECTED: The forward method is now model-aware ---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def forward(self, input_ids, attention_mask, token_type_ids=None):
|
| 22 |
"""
|
| 23 |
+
Forward pass for the cross-encoder. This version dynamically handles arguments
|
| 24 |
+
to support different model architectures.
|
| 25 |
|
| 26 |
Args:
|
| 27 |
input_ids (Tensor): Token IDs for the concatenated sequence.
|
| 28 |
attention_mask (Tensor): Attention mask for the input sequence.
|
| 29 |
token_type_ids (Tensor, optional): Segment IDs to distinguish query from chunk. Defaults to None.
|
| 30 |
"""
|
| 31 |
+
# 1. Create a dictionary with the arguments that are always required.
|
| 32 |
+
model_inputs = {
|
| 33 |
+
'input_ids': input_ids,
|
| 34 |
+
'attention_mask': attention_mask
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
# 2. Check if the underlying model's forward method accepts 'token_type_ids'.
|
| 38 |
+
# This makes the class compatible with models that use it (like BERT)
|
| 39 |
+
# and those that don't (like Qwen2).
|
| 40 |
+
if 'token_type_ids' in self.model.forward.__code__.co_varnames:
|
| 41 |
+
if token_type_ids is not None:
|
| 42 |
+
model_inputs['token_type_ids'] = token_type_ids
|
| 43 |
+
|
| 44 |
+
# 3. Pass the dynamically built arguments to the model using dictionary unpacking.
|
| 45 |
+
outputs = self.model(**model_inputs)
|
| 46 |
+
|
| 47 |
return outputs.logits
|
| 48 |
|
|
|
|
| 49 |
def get_tokenizer(model_name='bert-base-uncased'):
|
| 50 |
"""
|
| 51 |
Helper function to get the tokenizer corresponding to the model.
|