helal94hb1 commited on
Commit
8b1b13a
·
1 Parent(s): 67e765b

fix: abbreviation AA2

Browse files
app/api/v2_endpoints.py CHANGED
@@ -150,7 +150,7 @@ async def handle_v2_query(
150
  top_result_preview = None
151
  original_file = None
152
  try:
153
- # --- STEP 1: PRE-PROCESSING (Direct "TPH" Replacement) ---
154
  original_query = request.query
155
 
156
  # --- EDIT: Call the new, direct replacement function ---
 
150
  top_result_preview = None
151
  original_file = None
152
  try:
153
+ # --- STEP 1: PRE-PROCESSING (Direct ABBREVIATION Replacement) ---
154
  original_query = request.query
155
 
156
  # --- EDIT: Call the new, direct replacement function ---
app/services/query_expansion_service.py CHANGED
@@ -16,153 +16,6 @@ from app.core.config import settings
16
  logger = logging.getLogger(__name__)
17
 
18
 
19
- def load_t5_paraphraser():
20
- """
21
- Loads the T5 paraphrasing model and tokenizer into the central state.
22
- This should be called once on application startup.
23
- """
24
- if state.t5_paraphraser_loaded:
25
- logger.info("T5 paraphraser model already loaded in state.")
26
- return True
27
-
28
- # --- MODIFIED: Switched to a reliable, public T5 paraphrasing model ---
29
- model_name = getattr(settings, "T5_PARAPHRASER_MODEL_NAME", "humarin/chatgpt_paraphraser_on_T5_base")
30
- logger.info(f"Loading T5 paraphraser model: {model_name}...")
31
-
32
- try:
33
- state.t5_tokenizer = AutoTokenizer.from_pretrained(model_name)
34
- state.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
35
- state.t5_model.to(state.device)
36
- state.t5_model.eval()
37
- state.t5_paraphraser_loaded = True
38
- logger.info("T5 paraphraser model loaded successfully.")
39
- return True
40
- except Exception as e:
41
- logger.exception(f"Failed to load T5 paraphraser model: {e}")
42
- return False
43
-
44
- ### NEW: Function to load abbreviations at startup ###
45
- def load_abbreviations():
46
- """
47
- Loads the abbreviation mapping from a CSV file into the central state.
48
- """
49
- if state.abbreviations_loaded:
50
- logger.info("Abbreviation map already loaded in state.")
51
- return True
52
-
53
- file_path = settings.ABBREVIATION_FILE_PATH
54
- logger.info(f"Loading abbreviation map from: {file_path}")
55
-
56
- if not os.path.exists(file_path):
57
- logger.error(f"Abbreviation file not found at path: {file_path}")
58
- return False
59
-
60
- abbreviation_map = {}
61
- try:
62
- with open(file_path, mode='r', encoding='utf-8') as infile:
63
- reader = csv.reader(infile)
64
- # Skip header row
65
- next(reader, None)
66
- for row in reader:
67
- if len(row) >= 2:
68
- abbreviation = row[0].strip()
69
- original_text = row[1].strip()
70
- if abbreviation and original_text:
71
- # Store in lowercase for case-insensitive matching
72
- abbreviation_map[abbreviation.lower()] = original_text
73
-
74
- state.abbreviation_map = abbreviation_map
75
- state.abbreviations_loaded = True
76
- logger.info(f"Successfully loaded {len(abbreviation_map)} abbreviations.")
77
- return True
78
- except Exception as e:
79
- logger.exception(f"Failed to load or parse abbreviation file: {e}")
80
- return False
81
-
82
-
83
- ### NEW: Helper function to perform abbreviation expansion ###
84
- def _expand_with_abbreviations(query: str, abbrevation_map: Dict[str, str]) -> List[str]:
85
- """
86
- Generates new query variations by replacing known abbreviations.
87
- """
88
- expanded_queries = []
89
- # Use word boundaries to match whole words only, case-insensitively
90
- words = re.split(r'(\s+)', query)
91
-
92
- for i, word in enumerate(words):
93
- # Check the lowercased, punctuation-stripped word
94
- clean_word = re.sub(r'[^\w]', '', word).lower()
95
- if clean_word in abbrevation_map:
96
- # Create a new query with the replacement
97
- new_words = words[:]
98
- new_words[i] = abbrevation_map[clean_word]
99
- expanded_queries.append("".join(new_words))
100
-
101
- if expanded_queries:
102
- logger.info(f"Generated {len(expanded_queries)} variations from abbreviations.")
103
-
104
- return expanded_queries
105
-
106
-
107
- ### MODIFIED: Main function now combines both expansion strategies ###
108
- async def generate_query_variations(query: str, num_variations: int = 2) -> List[str]:
109
- """
110
- Uses both a local T5 model and an abbreviation map to generate paraphrases
111
- and expansions of the user's query.
112
- """
113
- all_variations = []
114
-
115
- # --- 1. T5 Paraphrasing ---
116
- if state.t5_paraphraser_loaded and state.t5_model and state.t5_tokenizer:
117
- try:
118
- input_text = query
119
- encoding = state.t5_tokenizer.encode_plus(input_text, padding="longest", return_tensors="pt")
120
- input_ids, attention_mask = encoding.input_ids.to(state.device), encoding.attention_mask.to(state.device)
121
-
122
- outputs = state.t5_model.generate(
123
- input_ids=input_ids,
124
- attention_mask=attention_mask,
125
- max_length=256,
126
- num_beams=10,
127
- num_return_sequences=num_variations,
128
- no_repeat_ngram_size=2,
129
- early_stopping=True
130
- )
131
-
132
- t5_variations = [
133
- state.t5_tokenizer.decode(seq, skip_special_tokens=True, clean_up_tokenization_spaces=True)
134
- for seq in outputs
135
- ]
136
- all_variations.extend(t5_variations)
137
- logger.info(f"Generated {len(t5_variations)} variations from T5 model.")
138
- except Exception as e:
139
- logger.exception("An error occurred during T5 query variation generation.")
140
- else:
141
- logger.warning("T5 paraphraser not loaded. Skipping AI paraphrasing.")
142
-
143
- # --- 2. Abbreviation Expansion ---
144
- if state.abbreviations_loaded and state.abbreviation_map:
145
- try:
146
- abbreviation_variations = _expand_with_abbreviations(query, state.abbreviation_map)
147
- all_variations.extend(abbreviation_variations)
148
- except Exception as e:
149
- logger.exception("An error occurred during abbreviation expansion.")
150
- else:
151
- logger.warning("Abbreviation map not loaded. Skipping abbreviation expansion.")
152
-
153
- # Return a unique list of variations
154
- return list(set(all_variations))
155
- # --- NEW: Simple, direct function for TPH replacement ---
156
- def expand_tph_in_query(query_text: str) -> str:
157
- """
158
- Performs a case-insensitive, whole-word replacement of "TPH" with "Payment Hub".
159
- """
160
- # \b ensures we match "TPH" as a whole word, not as part of another word like "GRAPH".
161
- # re.IGNORECASE makes the match case-insensitive (e.g., "tph", "Tph").
162
- pattern = r'\bTPH\b'
163
- replacement = "Payment Hub"
164
-
165
- return re.sub(pattern, replacement, query_text, flags=re.IGNORECASE)
166
 
167
  def replace_abbreviations(query_text: str) -> str:
168
  """
@@ -199,53 +52,3 @@ def replace_abbreviations(query_text: str) -> str:
199
 
200
  # 6. Perform the substitution and return the result.
201
  return pattern.sub(get_replacement, query_text)
202
- # async def generate_query_variations(query: str, num_variations: int = 2) -> List[str]:
203
- # """
204
- # Uses a local T5 model to generate paraphrases of the user's query.
205
-
206
- # Args:
207
- # query (str): The original user query.
208
- # num_variations (int): The number of variations to generate.
209
-
210
- # Returns:
211
- # List[str]: A list of paraphrased queries. Returns an empty list on failure.
212
- # """
213
- # if not state.t5_paraphraser_loaded or not state.t5_model or not state.t5_tokenizer:
214
- # logger.error("Cannot generate query variations: T5 paraphraser is not initialized.")
215
- # return []
216
-
217
- # try:
218
- # # --- MODIFIED: Removed the "paraphrase: " prefix as this model does not require it ---
219
- # input_text = query
220
-
221
- # # Tokenize the input
222
- # encoding = state.t5_tokenizer.encode_plus(
223
- # input_text,
224
- # padding="longest",
225
- # return_tensors="pt"
226
- # )
227
- # input_ids, attention_mask = encoding.input_ids.to(state.device), encoding.attention_mask.to(state.device)
228
-
229
- # # Generate variations
230
- # outputs = state.t5_model.generate(
231
- # input_ids=input_ids,
232
- # attention_mask=attention_mask,
233
- # max_length=256,
234
- # num_beams=10,
235
- # num_return_sequences=num_variations,
236
- # no_repeat_ngram_size=2,
237
- # early_stopping=True
238
- # )
239
-
240
- # # Decode the generated token IDs back to strings
241
- # variations = [
242
- # state.t5_tokenizer.decode(generated_sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True)
243
- # for generated_sequence in outputs
244
- # ]
245
-
246
- # logger.info(f"Generated {len(variations)} variations for query.")
247
- # return variations
248
-
249
- # except Exception as e:
250
- # logger.exception(f"An unexpected error occurred during T5 query variation generation: {e}")
251
- # return []
 
16
  logger = logging.getLogger(__name__)
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def replace_abbreviations(query_text: str) -> str:
21
  """
 
52
 
53
  # 6. Perform the substitution and return the result.
54
  return pattern.sub(get_replacement, query_text)