helal94hb1 commited on
Commit
101cdda
·
1 Parent(s): ffe6d33

fix: new embeddings and rerankh4

Browse files
Files changed (1) hide show
  1. app/models/expert_judge_model.py +20 -16
app/models/expert_judge_model.py CHANGED
@@ -1,8 +1,7 @@
1
  # expert_judge_model.py
2
 
3
  import torch.nn as nn
4
- # --- MODIFIED: Import AutoTokenizer for a more robust helper ---
5
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertTokenizer
6
 
7
  class ExpertJudgeCrossEncoder(nn.Module):
8
  """
@@ -18,30 +17,35 @@ class ExpertJudgeCrossEncoder(nn.Module):
18
  num_labels=1
19
  )
20
 
21
-
22
-
23
-
24
-
25
- # --- MODIFIED: The forward method now accepts token_type_ids ---
26
  def forward(self, input_ids, attention_mask, token_type_ids=None):
27
  """
28
- Forward pass for the cross-encoder.
 
29
 
30
  Args:
31
  input_ids (Tensor): Token IDs for the concatenated sequence.
32
  attention_mask (Tensor): Attention mask for the input sequence.
33
  token_type_ids (Tensor, optional): Segment IDs to distinguish query from chunk. Defaults to None.
34
  """
35
- # Pass all arguments to the underlying Hugging Face model.
36
- # It will use token_type_ids if the model architecture supports it (like BERT).
37
- outputs = self.model(
38
- input_ids=input_ids,
39
- attention_mask=attention_mask,
40
- token_type_ids=token_type_ids
41
- )
 
 
 
 
 
 
 
 
 
42
  return outputs.logits
43
 
44
- # --- MODIFIED: A more robust tokenizer helper function ---
45
  def get_tokenizer(model_name='bert-base-uncased'):
46
  """
47
  Helper function to get the tokenizer corresponding to the model.
 
1
  # expert_judge_model.py
2
 
3
  import torch.nn as nn
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
5
 
6
  class ExpertJudgeCrossEncoder(nn.Module):
7
  """
 
17
  num_labels=1
18
  )
19
 
20
+ # --- CORRECTED: The forward method is now model-aware ---
 
 
 
 
21
  def forward(self, input_ids, attention_mask, token_type_ids=None):
22
  """
23
+ Forward pass for the cross-encoder. This version dynamically handles arguments
24
+ to support different model architectures.
25
 
26
  Args:
27
  input_ids (Tensor): Token IDs for the concatenated sequence.
28
  attention_mask (Tensor): Attention mask for the input sequence.
29
  token_type_ids (Tensor, optional): Segment IDs to distinguish query from chunk. Defaults to None.
30
  """
31
+ # 1. Create a dictionary with the arguments that are always required.
32
+ model_inputs = {
33
+ 'input_ids': input_ids,
34
+ 'attention_mask': attention_mask
35
+ }
36
+
37
+ # 2. Check if the underlying model's forward method accepts 'token_type_ids'.
38
+ # This makes the class compatible with models that use it (like BERT)
39
+ # and those that don't (like Qwen2).
40
+ if 'token_type_ids' in self.model.forward.__code__.co_varnames:
41
+ if token_type_ids is not None:
42
+ model_inputs['token_type_ids'] = token_type_ids
43
+
44
+ # 3. Pass the dynamically built arguments to the model using dictionary unpacking.
45
+ outputs = self.model(**model_inputs)
46
+
47
  return outputs.logits
48
 
 
49
  def get_tokenizer(model_name='bert-base-uncased'):
50
  """
51
  Helper function to get the tokenizer corresponding to the model.