Spaces:
Sleeping
Sleeping
Commit
·
8b1b13a
1
Parent(s):
67e765b
fix: abbreviation AA2
Browse files
app/api/v2_endpoints.py
CHANGED
|
@@ -150,7 +150,7 @@ async def handle_v2_query(
|
|
| 150 |
top_result_preview = None
|
| 151 |
original_file = None
|
| 152 |
try:
|
| 153 |
-
# --- STEP 1: PRE-PROCESSING (Direct
|
| 154 |
original_query = request.query
|
| 155 |
|
| 156 |
# --- EDIT: Call the new, direct replacement function ---
|
|
|
|
| 150 |
top_result_preview = None
|
| 151 |
original_file = None
|
| 152 |
try:
|
| 153 |
+
# --- STEP 1: PRE-PROCESSING (Direct ABBREVIATION Replacement) ---
|
| 154 |
original_query = request.query
|
| 155 |
|
| 156 |
# --- EDIT: Call the new, direct replacement function ---
|
app/services/query_expansion_service.py
CHANGED
|
@@ -16,153 +16,6 @@ from app.core.config import settings
|
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
|
| 19 |
-
def load_t5_paraphraser():
|
| 20 |
-
"""
|
| 21 |
-
Loads the T5 paraphrasing model and tokenizer into the central state.
|
| 22 |
-
This should be called once on application startup.
|
| 23 |
-
"""
|
| 24 |
-
if state.t5_paraphraser_loaded:
|
| 25 |
-
logger.info("T5 paraphraser model already loaded in state.")
|
| 26 |
-
return True
|
| 27 |
-
|
| 28 |
-
# --- MODIFIED: Switched to a reliable, public T5 paraphrasing model ---
|
| 29 |
-
model_name = getattr(settings, "T5_PARAPHRASER_MODEL_NAME", "humarin/chatgpt_paraphraser_on_T5_base")
|
| 30 |
-
logger.info(f"Loading T5 paraphraser model: {model_name}...")
|
| 31 |
-
|
| 32 |
-
try:
|
| 33 |
-
state.t5_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 34 |
-
state.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 35 |
-
state.t5_model.to(state.device)
|
| 36 |
-
state.t5_model.eval()
|
| 37 |
-
state.t5_paraphraser_loaded = True
|
| 38 |
-
logger.info("T5 paraphraser model loaded successfully.")
|
| 39 |
-
return True
|
| 40 |
-
except Exception as e:
|
| 41 |
-
logger.exception(f"Failed to load T5 paraphraser model: {e}")
|
| 42 |
-
return False
|
| 43 |
-
|
| 44 |
-
### NEW: Function to load abbreviations at startup ###
|
| 45 |
-
def load_abbreviations():
|
| 46 |
-
"""
|
| 47 |
-
Loads the abbreviation mapping from a CSV file into the central state.
|
| 48 |
-
"""
|
| 49 |
-
if state.abbreviations_loaded:
|
| 50 |
-
logger.info("Abbreviation map already loaded in state.")
|
| 51 |
-
return True
|
| 52 |
-
|
| 53 |
-
file_path = settings.ABBREVIATION_FILE_PATH
|
| 54 |
-
logger.info(f"Loading abbreviation map from: {file_path}")
|
| 55 |
-
|
| 56 |
-
if not os.path.exists(file_path):
|
| 57 |
-
logger.error(f"Abbreviation file not found at path: {file_path}")
|
| 58 |
-
return False
|
| 59 |
-
|
| 60 |
-
abbreviation_map = {}
|
| 61 |
-
try:
|
| 62 |
-
with open(file_path, mode='r', encoding='utf-8') as infile:
|
| 63 |
-
reader = csv.reader(infile)
|
| 64 |
-
# Skip header row
|
| 65 |
-
next(reader, None)
|
| 66 |
-
for row in reader:
|
| 67 |
-
if len(row) >= 2:
|
| 68 |
-
abbreviation = row[0].strip()
|
| 69 |
-
original_text = row[1].strip()
|
| 70 |
-
if abbreviation and original_text:
|
| 71 |
-
# Store in lowercase for case-insensitive matching
|
| 72 |
-
abbreviation_map[abbreviation.lower()] = original_text
|
| 73 |
-
|
| 74 |
-
state.abbreviation_map = abbreviation_map
|
| 75 |
-
state.abbreviations_loaded = True
|
| 76 |
-
logger.info(f"Successfully loaded {len(abbreviation_map)} abbreviations.")
|
| 77 |
-
return True
|
| 78 |
-
except Exception as e:
|
| 79 |
-
logger.exception(f"Failed to load or parse abbreviation file: {e}")
|
| 80 |
-
return False
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
### NEW: Helper function to perform abbreviation expansion ###
|
| 84 |
-
def _expand_with_abbreviations(query: str, abbrevation_map: Dict[str, str]) -> List[str]:
|
| 85 |
-
"""
|
| 86 |
-
Generates new query variations by replacing known abbreviations.
|
| 87 |
-
"""
|
| 88 |
-
expanded_queries = []
|
| 89 |
-
# Split on whitespace (keeping the separators) so each token can be matched as a whole word, case-insensitively
|
| 90 |
-
words = re.split(r'(\s+)', query)
|
| 91 |
-
|
| 92 |
-
for i, word in enumerate(words):
|
| 93 |
-
# Check the lowercased, punctuation-stripped word
|
| 94 |
-
clean_word = re.sub(r'[^\w]', '', word).lower()
|
| 95 |
-
if clean_word in abbrevation_map:
|
| 96 |
-
# Create a new query with the replacement
|
| 97 |
-
new_words = words[:]
|
| 98 |
-
new_words[i] = abbrevation_map[clean_word]
|
| 99 |
-
expanded_queries.append("".join(new_words))
|
| 100 |
-
|
| 101 |
-
if expanded_queries:
|
| 102 |
-
logger.info(f"Generated {len(expanded_queries)} variations from abbreviations.")
|
| 103 |
-
|
| 104 |
-
return expanded_queries
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
### MODIFIED: Main function now combines both expansion strategies ###
|
| 108 |
-
async def generate_query_variations(query: str, num_variations: int = 2) -> List[str]:
|
| 109 |
-
"""
|
| 110 |
-
Uses both a local T5 model and an abbreviation map to generate paraphrases
|
| 111 |
-
and expansions of the user's query.
|
| 112 |
-
"""
|
| 113 |
-
all_variations = []
|
| 114 |
-
|
| 115 |
-
# --- 1. T5 Paraphrasing ---
|
| 116 |
-
if state.t5_paraphraser_loaded and state.t5_model and state.t5_tokenizer:
|
| 117 |
-
try:
|
| 118 |
-
input_text = query
|
| 119 |
-
encoding = state.t5_tokenizer.encode_plus(input_text, padding="longest", return_tensors="pt")
|
| 120 |
-
input_ids, attention_mask = encoding.input_ids.to(state.device), encoding.attention_mask.to(state.device)
|
| 121 |
-
|
| 122 |
-
outputs = state.t5_model.generate(
|
| 123 |
-
input_ids=input_ids,
|
| 124 |
-
attention_mask=attention_mask,
|
| 125 |
-
max_length=256,
|
| 126 |
-
num_beams=10,
|
| 127 |
-
num_return_sequences=num_variations,
|
| 128 |
-
no_repeat_ngram_size=2,
|
| 129 |
-
early_stopping=True
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
t5_variations = [
|
| 133 |
-
state.t5_tokenizer.decode(seq, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
| 134 |
-
for seq in outputs
|
| 135 |
-
]
|
| 136 |
-
all_variations.extend(t5_variations)
|
| 137 |
-
logger.info(f"Generated {len(t5_variations)} variations from T5 model.")
|
| 138 |
-
except Exception as e:
|
| 139 |
-
logger.exception("An error occurred during T5 query variation generation.")
|
| 140 |
-
else:
|
| 141 |
-
logger.warning("T5 paraphraser not loaded. Skipping AI paraphrasing.")
|
| 142 |
-
|
| 143 |
-
# --- 2. Abbreviation Expansion ---
|
| 144 |
-
if state.abbreviations_loaded and state.abbreviation_map:
|
| 145 |
-
try:
|
| 146 |
-
abbreviation_variations = _expand_with_abbreviations(query, state.abbreviation_map)
|
| 147 |
-
all_variations.extend(abbreviation_variations)
|
| 148 |
-
except Exception as e:
|
| 149 |
-
logger.exception("An error occurred during abbreviation expansion.")
|
| 150 |
-
else:
|
| 151 |
-
logger.warning("Abbreviation map not loaded. Skipping abbreviation expansion.")
|
| 152 |
-
|
| 153 |
-
# Return a unique list of variations
|
| 154 |
-
return list(set(all_variations))
|
| 155 |
-
# --- NEW: Simple, direct function for TPH replacement ---
|
| 156 |
-
def expand_tph_in_query(query_text: str) -> str:
|
| 157 |
-
"""
|
| 158 |
-
Performs a case-insensitive, whole-word replacement of "TPH" with "Payment Hub".
|
| 159 |
-
"""
|
| 160 |
-
# \b ensures we match "TPH" as a whole word, not as part of a longer token like "TPHX".
|
| 161 |
-
# re.IGNORECASE makes the match case-insensitive (e.g., "tph", "Tph").
|
| 162 |
-
pattern = r'\bTPH\b'
|
| 163 |
-
replacement = "Payment Hub"
|
| 164 |
-
|
| 165 |
-
return re.sub(pattern, replacement, query_text, flags=re.IGNORECASE)
|
| 166 |
|
| 167 |
def replace_abbreviations(query_text: str) -> str:
|
| 168 |
"""
|
|
@@ -199,53 +52,3 @@ def replace_abbreviations(query_text: str) -> str:
|
|
| 199 |
|
| 200 |
# 6. Perform the substitution and return the result.
|
| 201 |
return pattern.sub(get_replacement, query_text)
|
| 202 |
-
# async def generate_query_variations(query: str, num_variations: int = 2) -> List[str]:
|
| 203 |
-
# """
|
| 204 |
-
# Uses a local T5 model to generate paraphrases of the user's query.
|
| 205 |
-
|
| 206 |
-
# Args:
|
| 207 |
-
# query (str): The original user query.
|
| 208 |
-
# num_variations (int): The number of variations to generate.
|
| 209 |
-
|
| 210 |
-
# Returns:
|
| 211 |
-
# List[str]: A list of paraphrased queries. Returns an empty list on failure.
|
| 212 |
-
# """
|
| 213 |
-
# if not state.t5_paraphraser_loaded or not state.t5_model or not state.t5_tokenizer:
|
| 214 |
-
# logger.error("Cannot generate query variations: T5 paraphraser is not initialized.")
|
| 215 |
-
# return []
|
| 216 |
-
|
| 217 |
-
# try:
|
| 218 |
-
# # --- MODIFIED: Removed the "paraphrase: " prefix as this model does not require it ---
|
| 219 |
-
# input_text = query
|
| 220 |
-
|
| 221 |
-
# # Tokenize the input
|
| 222 |
-
# encoding = state.t5_tokenizer.encode_plus(
|
| 223 |
-
# input_text,
|
| 224 |
-
# padding="longest",
|
| 225 |
-
# return_tensors="pt"
|
| 226 |
-
# )
|
| 227 |
-
# input_ids, attention_mask = encoding.input_ids.to(state.device), encoding.attention_mask.to(state.device)
|
| 228 |
-
|
| 229 |
-
# # Generate variations
|
| 230 |
-
# outputs = state.t5_model.generate(
|
| 231 |
-
# input_ids=input_ids,
|
| 232 |
-
# attention_mask=attention_mask,
|
| 233 |
-
# max_length=256,
|
| 234 |
-
# num_beams=10,
|
| 235 |
-
# num_return_sequences=num_variations,
|
| 236 |
-
# no_repeat_ngram_size=2,
|
| 237 |
-
# early_stopping=True
|
| 238 |
-
# )
|
| 239 |
-
|
| 240 |
-
# # Decode the generated token IDs back to strings
|
| 241 |
-
# variations = [
|
| 242 |
-
# state.t5_tokenizer.decode(generated_sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
| 243 |
-
# for generated_sequence in outputs
|
| 244 |
-
# ]
|
| 245 |
-
|
| 246 |
-
# logger.info(f"Generated {len(variations)} variations for query.")
|
| 247 |
-
# return variations
|
| 248 |
-
|
| 249 |
-
# except Exception as e:
|
| 250 |
-
# logger.exception(f"An unexpected error occurred during T5 query variation generation: {e}")
|
| 251 |
-
# return []
|
|
|
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def replace_abbreviations(query_text: str) -> str:
|
| 21 |
"""
|
|
|
|
| 52 |
|
| 53 |
# 6. Perform the substitution and return the result.
|
| 54 |
return pattern.sub(get_replacement, query_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|