# app/services/query_expansion_service.py
import csv
import logging
import os
import re
from typing import Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# --- Core App Imports ---
from app.core import state
from app.core.config import settings

logger = logging.getLogger(__name__)

# Abbreviation -> expansion. Keys MUST be lowercase: every regex match is
# lowercased before it is looked up in this table.
_ABBREVIATIONS: Dict[str, str] = {
    'tph': 'Payment Hub',
    'aa': 'Arrangement',
    # Add other replacements here.
}

# Whole-word, case-insensitive check for an existing "Arrangement".
_ARRANGEMENT_WORD = re.compile(r'\bArrangement\b', re.IGNORECASE)


def replace_abbreviations(query_text: str) -> str:
    """Expand the known abbreviations found in *query_text*.

    Matching is case-insensitive and restricted to whole words. The
    abbreviation "AA" is expanded to "Arrangement" only when the word
    "Arrangement" does not already occur in the query; every other rule is
    applied unconditionally.

    Args:
        query_text: The raw query string.

    Returns:
        The query with each matching abbreviation replaced by its expansion.
        The text is returned unchanged when no rule is active or nothing
        matches.

    Examples:
        >>> replace_abbreviations("How do I configure TPH?")
        'How do I configure Payment Hub?'
        >>> replace_abbreviations("AA arrangement details")
        'AA arrangement details'
    """
    # Drop the "aa" rule for this call when "Arrangement" is already present,
    # so the query does not end up saying "Arrangement Arrangement".
    skip_aa = _ARRANGEMENT_WORD.search(query_text) is not None
    active_rules = {
        abbreviation: expansion
        for abbreviation, expansion in _ABBREVIATIONS.items()
        if not (skip_aa and abbreviation == 'aa')
    }

    # An empty alternation would match at every word boundary, so bail out
    # before building a pattern from zero rules.
    if not active_rules:
        return query_text

    # Escape the keys so an abbreviation containing a regex metacharacter
    # (e.g. "c++") cannot corrupt the pattern.
    alternation = '|'.join(re.escape(abbreviation) for abbreviation in active_rules)
    pattern = re.compile(r'\b(?:' + alternation + r')\b', re.IGNORECASE)

    def expand(match: "re.Match[str]") -> str:
        # The match may be in any case; the table keys are lowercase.
        return active_rules[match.group(0).lower()]

    return pattern.sub(expand, query_text)