backend_chatbot / app /services /query_expansion_service.py
helal94hb1's picture
fix: abbreviation AA2
8b1b13a
# app/services/query_expansion_service.py
import logging
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import csv
import re
import os
from typing import Dict
import re # <-- NEW IMPORT for regular expressions
# --- Core App Imports ---
from app.core import state
from app.core.config import settings
logger = logging.getLogger(__name__)
def replace_abbreviations(query_text: str) -> str:
    """
    Expand a predefined set of abbreviations found in a query.

    Matching is case-insensitive and restricted to whole words (regex ``\\b``
    boundaries), so "TPH" and "tph" are expanded but "AA2" or "graph" are
    left alone. "AA" is only expanded when the word "Arrangement" does not
    already appear in the query.

    Args:
        query_text: The raw query string.

    Returns:
        The query with every active abbreviation replaced by its expansion.
        The input is returned unchanged when nothing matches.
    """
    # Keys MUST be lowercase: matched text is lower-cased before the lookup.
    replacements = {
        'tph': 'Payment Hub',
        'aa': 'Arrangement',
        # Add other unconditional replacements here.
    }

    # Skip the "AA" rule when the query already spells out "Arrangement"
    # (whole word, any casing).
    if re.search(r'\bArrangement\b', query_text, re.IGNORECASE):
        replacements.pop('aa', None)

    # An empty alternation would match the empty string at every word
    # boundary, so bail out when no rule is active.
    if not replacements:
        return query_text

    # Escape each key so abbreviations containing regex metacharacters are
    # matched literally, and try longer keys first so a short key can never
    # pre-empt a longer one that starts with the same characters.
    alternation = '|'.join(
        re.escape(key)
        for key in sorted(replacements, key=len, reverse=True)
    )
    pattern = re.compile(r'\b(?:' + alternation + r')\b', re.IGNORECASE)

    def get_replacement(match):
        matched_text = match.group(0)
        # Case-insensitive Unicode matching can accept characters whose
        # .lower() form is not one of our keys; keep such text unchanged
        # rather than raising KeyError.
        return replacements.get(matched_text.lower(), matched_text)

    return pattern.sub(get_replacement, query_text)