mtyrrell committed
Commit 8f0a9cd · 1 Parent(s): f2a3674

ts citation filtering

Files changed (1)
  1. utils/generator.py +36 -6
utils/generator.py CHANGED

@@ -16,6 +16,9 @@ from langchain_core.messages import SystemMessage, HumanMessage
 # Local imports
 from .utils import getconfig, get_auth
 
+# Set up logger
+logger = logging.getLogger(__name__)
+
 # ---------------------------------------------------------------------
 # Configuration and Model Initialization
 # ---------------------------------------------------------------------
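Note that the logger.info calls introduced below are only visible if the embedding application configures logging at INFO level or lower. A minimal sketch of such a configuration; the format string and level here are assumptions, not part of this commit:

import logging

# Route INFO-and-above records from every module to stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

# Because generator.py uses getLogger(__name__), its output can later
# be quieted without editing the module itself:
logging.getLogger("utils.generator").setLevel(logging.WARNING)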
@@ -57,18 +60,45 @@ def _parse_citations(response: str) -> List[int]:
     """Parse citation numbers from response text"""
     citation_pattern = r'\[(\d+)\]'
     matches = re.findall(citation_pattern, response)
-    return sorted(list(set(int(match) for match in matches)))
+    citation_numbers = sorted(list(set(int(match) for match in matches)))
+
+    # Debug logging
+    logger.info(f"=== CITATION PARSING DEBUG ===")
+    logger.info(f"Response text length: {len(response)}")
+    logger.info(f"Found citation matches: {matches}")
+    logger.info(f"Parsed citation numbers: {citation_numbers}")
+
+    return citation_numbers
 
 def _extract_sources(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
     """Extract sources that were cited in the response"""
+    # Debug logging - show all available sources before filtering
+    logger.info(f"=== SOURCE FILTERING DEBUG ===")
+    logger.info(f"Total available sources: {len(processed_results)}")
+    for i, source in enumerate(processed_results):
+        logger.info(f"Source {i+1}: filename='{source.get('filename', 'Unknown')}', page='{source.get('page', 'Unknown')}', year='{source.get('year', 'Unknown')}'")
+
+    logger.info(f"Cited numbers from response: {cited_numbers}")
+
     if not cited_numbers:
+        logger.info("No citations found - returning empty sources list")
         return []
 
     cited_sources = []
     for citation_num in cited_numbers:
         source_index = citation_num - 1
+        logger.info(f"Processing citation [{citation_num}] -> source_index: {source_index}")
+
         if 0 <= source_index < len(processed_results):
-            cited_sources.append(processed_results[source_index])
+            source = processed_results[source_index]
+            cited_sources.append(source)
+            logger.info(f"✓ Added source {citation_num}: filename='{source.get('filename', 'Unknown')}', page='{source.get('page', 'Unknown')}'")
+        else:
+            logger.warning(f"✗ Citation [{citation_num}] is out of range! source_index {source_index} not in range [0, {len(processed_results)-1}]")
+
+    logger.info(f"Final filtered sources count: {len(cited_sources)}")
+    for i, source in enumerate(cited_sources):
+        logger.info(f"Filtered source {i+1}: filename='{source.get('filename', 'Unknown')}', page='{source.get('page', 'Unknown')}', year='{source.get('year', 'Unknown')}'")
 
     return cited_sources
 
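For reference, a quick sketch of how the two helpers interact after this change; the sample response text and source dicts are invented, and the import assumes the module resolves as utils.generator:

from utils.generator import _parse_citations, _extract_sources  # assumed import path

response = "Targets are set in [2] and reaffirmed in [5], see also [2]."
sources = [
    {"filename": f"doc_{i}.pdf", "page": i, "year": 2020 + i}
    for i in range(1, 5)  # four sources -> valid citation range is [1]..[4]
]

cited = _parse_citations(response)       # [2, 5]: deduplicated and sorted
kept = _extract_sources(sources, cited)  # keeps doc_2.pdf; [5] triggers the out-of-range warning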
 
@@ -203,7 +233,7 @@ async def _call_llm(messages: list) -> str:
         response = await chat_model.ainvoke(messages)
         return response.content.strip()
     except Exception as e:
-        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+        logger.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
         raise
 
 async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
@@ -213,7 +243,7 @@ async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
         if hasattr(chunk, 'content') and chunk.content:
             yield chunk.content
     except Exception as e:
-        logging.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+        logger.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
         yield f"Error: {str(e)}"
 
 # ---------------------------------------------------------------------
@@ -246,7 +276,7 @@ async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui
         return answer
 
     except Exception as e:
-        logging.exception("Generation failed")
+        logger.exception("Generation failed")
         error_msg = str(e)
         return {"error": error_msg} if chatui_format else f"Error: {error_msg}"
@@ -290,7 +320,7 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
         yield {"event": "end", "data": {}}
 
     except Exception as e:
-        logging.exception("Streaming generation failed")
+        logger.exception("Streaming generation failed")
         error_msg = str(e)
         if chatui_format:
             yield {"event": "error", "data": {"error": error_msg}}
 