Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
ts citation filtering
Browse files- utils/generator.py +36 -6
utils/generator.py
CHANGED
|
@@ -16,6 +16,9 @@ from langchain_core.messages import SystemMessage, HumanMessage
|
|
| 16 |
# Local imports
|
| 17 |
from .utils import getconfig, get_auth
|
| 18 |
|
|
|
|
|
|
|
|
|
|
| 19 |
# ---------------------------------------------------------------------
|
| 20 |
# Configuration and Model Initialization
|
| 21 |
# ---------------------------------------------------------------------
|
|
@@ -57,18 +60,45 @@ def _parse_citations(response: str) -> List[int]:
|
|
| 57 |
"""Parse citation numbers from response text"""
|
| 58 |
citation_pattern = r'\[(\d+)\]'
|
| 59 |
matches = re.findall(citation_pattern, response)
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def _extract_sources(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
|
| 63 |
"""Extract sources that were cited in the response"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
if not cited_numbers:
|
|
|
|
| 65 |
return []
|
| 66 |
|
| 67 |
cited_sources = []
|
| 68 |
for citation_num in cited_numbers:
|
| 69 |
source_index = citation_num - 1
|
|
|
|
|
|
|
| 70 |
if 0 <= source_index < len(processed_results):
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
return cited_sources
|
| 74 |
|
|
@@ -203,7 +233,7 @@ async def _call_llm(messages: list) -> str:
|
|
| 203 |
response = await chat_model.ainvoke(messages)
|
| 204 |
return response.content.strip()
|
| 205 |
except Exception as e:
|
| 206 |
-
|
| 207 |
raise
|
| 208 |
|
| 209 |
async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
|
|
@@ -213,7 +243,7 @@ async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
|
|
| 213 |
if hasattr(chunk, 'content') and chunk.content:
|
| 214 |
yield chunk.content
|
| 215 |
except Exception as e:
|
| 216 |
-
|
| 217 |
yield f"Error: {str(e)}"
|
| 218 |
|
| 219 |
# ---------------------------------------------------------------------
|
|
@@ -246,7 +276,7 @@ async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui
|
|
| 246 |
return answer
|
| 247 |
|
| 248 |
except Exception as e:
|
| 249 |
-
|
| 250 |
error_msg = str(e)
|
| 251 |
return {"error": error_msg} if chatui_format else f"Error: {error_msg}"
|
| 252 |
|
|
@@ -290,7 +320,7 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
|
|
| 290 |
yield {"event": "end", "data": {}}
|
| 291 |
|
| 292 |
except Exception as e:
|
| 293 |
-
|
| 294 |
error_msg = str(e)
|
| 295 |
if chatui_format:
|
| 296 |
yield {"event": "error", "data": {"error": error_msg}}
|
|
|
|
| 16 |
# Local imports
|
| 17 |
from .utils import getconfig, get_auth
|
| 18 |
|
| 19 |
+
# Set up logger
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
# ---------------------------------------------------------------------
|
| 23 |
# Configuration and Model Initialization
|
| 24 |
# ---------------------------------------------------------------------
|
|
|
|
| 60 |
"""Parse citation numbers from response text"""
|
| 61 |
citation_pattern = r'\[(\d+)\]'
|
| 62 |
matches = re.findall(citation_pattern, response)
|
| 63 |
+
citation_numbers = sorted(list(set(int(match) for match in matches)))
|
| 64 |
+
|
| 65 |
+
# Debug logging
|
| 66 |
+
logger.info(f"=== CITATION PARSING DEBUG ===")
|
| 67 |
+
logger.info(f"Response text length: {len(response)}")
|
| 68 |
+
logger.info(f"Found citation matches: {matches}")
|
| 69 |
+
logger.info(f"Parsed citation numbers: {citation_numbers}")
|
| 70 |
+
|
| 71 |
+
return citation_numbers
|
| 72 |
|
| 73 |
def _extract_sources(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
    """Return the entries of ``processed_results`` that the response cited.

    Citation numbers are 1-based: citation ``[n]`` refers to
    ``processed_results[n - 1]``. Numbers that fall outside the available
    range are skipped and reported with a warning rather than raising.

    Args:
        processed_results: Candidate source dicts, in the order they were
            numbered for the model. Only the ``filename``, ``page`` and
            ``year`` keys are read here, and only for logging.
        cited_numbers: Citation numbers parsed from the response text.

    Returns:
        The cited source dicts, in the order given by ``cited_numbers``.
        Empty when ``cited_numbers`` is empty or no number is in range.
    """
    total_sources = len(processed_results)

    # Troubleshooting output: list every available source before filtering.
    # Arguments are passed lazily so no string formatting happens when the
    # corresponding log level is disabled.
    logger.info("=== SOURCE FILTERING DEBUG ===")
    logger.info("Total available sources: %s", total_sources)
    for position, source in enumerate(processed_results, start=1):
        logger.info(
            "Source %s: filename='%s', page='%s', year='%s'",
            position,
            source.get('filename', 'Unknown'),
            source.get('page', 'Unknown'),
            source.get('year', 'Unknown'),
        )

    logger.info("Cited numbers from response: %s", cited_numbers)

    if not cited_numbers:
        logger.info("No citations found - returning empty sources list")
        return []

    cited_sources = []
    for citation_num in cited_numbers:
        # Citations are 1-based in the response text; the list is 0-based.
        source_index = citation_num - 1
        logger.info("Processing citation [%s] -> source_index: %s", citation_num, source_index)

        if 0 <= source_index < total_sources:
            source = processed_results[source_index]
            cited_sources.append(source)
            logger.info(
                "✓ Added source %s: filename='%s', page='%s'",
                citation_num,
                source.get('filename', 'Unknown'),
                source.get('page', 'Unknown'),
            )
        else:
            logger.warning(
                "✗ Citation [%s] is out of range! source_index %s not in range [0, %s]",
                citation_num,
                source_index,
                total_sources - 1,
            )

    logger.info("Final filtered sources count: %s", len(cited_sources))
    for position, source in enumerate(cited_sources, start=1):
        logger.info(
            "Filtered source %s: filename='%s', page='%s', year='%s'",
            position,
            source.get('filename', 'Unknown'),
            source.get('page', 'Unknown'),
            source.get('year', 'Unknown'),
        )

    return cited_sources
|
| 104 |
|
|
|
|
| 233 |
response = await chat_model.ainvoke(messages)
|
| 234 |
return response.content.strip()
|
| 235 |
except Exception as e:
|
| 236 |
+
logger.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
|
| 237 |
raise
|
| 238 |
|
| 239 |
async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
|
|
|
|
| 243 |
if hasattr(chunk, 'content') and chunk.content:
|
| 244 |
yield chunk.content
|
| 245 |
except Exception as e:
|
| 246 |
+
logger.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
|
| 247 |
yield f"Error: {str(e)}"
|
| 248 |
|
| 249 |
# ---------------------------------------------------------------------
|
|
|
|
| 276 |
return answer
|
| 277 |
|
| 278 |
except Exception as e:
|
| 279 |
+
logger.exception("Generation failed")
|
| 280 |
error_msg = str(e)
|
| 281 |
return {"error": error_msg} if chatui_format else f"Error: {error_msg}"
|
| 282 |
|
|
|
|
| 320 |
yield {"event": "end", "data": {}}
|
| 321 |
|
| 322 |
except Exception as e:
|
| 323 |
+
logger.exception("Streaming generation failed")
|
| 324 |
error_msg = str(e)
|
| 325 |
if chatui_format:
|
| 326 |
yield {"event": "error", "data": {"error": error_msg}}
|