# NOTE(review): removed paste/extraction artifacts ("Spaces:", "Running", "Running")
# that preceded the module and were not valid Python.
| import os | |
| import re | |
| from concurrent.futures import ThreadPoolExecutor | |
| from datetime import datetime | |
| from typing import Optional, List | |
| import pytz | |
| from langchain.schema import Document | |
| from langchain.tools import tool | |
| from .retrievers import hybrid_search | |
| from .context_enrichment import enrich_retrieved_documents | |
| from .config import logger | |
# Canonical provider names - For HBV: SASLT only
CANONICAL_PROVIDERS = ["SASLT"]

# Global configuration for medical_guidelines_knowledge_tool retrieval/enrichment
TOOL_K_VECTOR = 5      # vector-search hits per provider
TOOL_K_BM25 = 2        # BM25 hits per provider
TOOL_PAGES_BEFORE = 1  # context pages prepended to each top result
TOOL_PAGES_AFTER = 1   # context pages appended to each top result
TOOL_MAX_ENRICHED = 2  # top documents per provider that receive context enrichment

# Module-level state captured by the tool for later validation
_last_question = None   # most recent tool query
_last_documents = None  # dict snapshots of the documents returned for that query

# Worker pool sized to the host: clamped to the [2, 8] range
TOOL_MAX_WORKERS = max(2, min(8, (os.cpu_count() or 4)))
_tool_executor = ThreadPoolExecutor(max_workers=TOOL_MAX_WORKERS)

# Map lowercase variants and full names to canonical provider codes
_PROVIDER_ALIASES = {
    "saslt": "SASLT",
    "saslt 2021": "SASLT",
    "saudi association for the study of liver diseases and transplantation": "SASLT",
    "saslt guidelines": "SASLT",
}
def _normalize_provider(provider: Optional[str], query: str) -> Optional[str]:
    """Resolve a canonical provider code from an explicit name or the query text.

    The explicit ``provider`` is checked first; when it yields no match the
    function retries against ``query``. Returns ``None`` when nothing matches.
    """
    candidate = provider or query
    if not candidate:
        return None

    lowered = candidate.lower()

    # Direct hits on canonical codes (word-bounded so partial tokens don't match)
    for code in CANONICAL_PROVIDERS:
        if re.search(rf"\b{re.escape(code.lower())}\b", lowered):
            return code

    # Known aliases and full organization names (plain substring match)
    for alias, code in _PROVIDER_ALIASES.items():
        if alias in lowered:
            return code

    # Explicit provider didn't resolve — fall back to scanning the query text
    if provider and provider != query:
        return _normalize_provider(None, query)
    return None
def clear_text(text: str) -> str:
    """Strip markdown artifacts from *text* and normalize its whitespace."""
    if not text:
        return ""

    # Unify newline conventions first so later regexes see only "\n"
    cleaned = text.replace("\r\n", "\n").replace("\r", "\n")

    # Links become "title (url)"; image syntax is dropped outright.
    # (Order matters: the link rule runs first, exactly as before.)
    cleaned = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1 (\2)", cleaned)
    cleaned = re.sub(r"!\[[^\]]*\]\([^)]*\)", "", cleaned)

    # Drop heading/blockquote markers at line starts
    cleaned = re.sub(r"(?m)^[>\s]*#{1,6}\s*", "", cleaned)

    # Remove code fences, inline code, and emphasis markers (same order as before)
    for marker in ("```", "`", "**", "*", "_"):
        cleaned = cleaned.replace(marker, "")

    # Whitespace cleanup: trailing spaces, blank-line runs, space runs
    cleaned = re.sub(r"[ \t]+\n", "\n", cleaned)
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    cleaned = re.sub(r"[ \t]{2,}", " ", cleaned)

    return cleaned.strip()
def _format_docs_with_citations(docs: List[Document], group_by_provider: bool = False) -> str:
    """Render retrieved documents as numbered citation blocks.

    When ``group_by_provider`` is True the output is sectioned per provider.
    Returns "No results." for an empty input.
    """
    if not docs:
        return "No results."
    if group_by_provider:
        return _format_grouped_by_provider(docs)

    citations = [
        _build_citation(n, doc.metadata or {}, doc.page_content)
        for n, doc in enumerate(docs, start=1)
    ]
    return "\n\n".join(citations)
def _build_citation(index: int, metadata: dict, content: str, include_provider: bool = True) -> str:
    """Compose one formatted citation entry from a document's metadata and text."""
    fields = []
    if include_provider:
        fields.append(f"Provider: {metadata.get('provider', 'unknown')}")
    fields.append(f"Disease: {metadata.get('disease', 'unknown')}")
    fields.append(f"Source: {metadata.get('source', 'unknown')}")
    fields.append(f"Page: {metadata.get('page_number', '?')}")

    header = f"📄 Result {index}:\n" + " | ".join(fields)
    # Pages pulled in by context enrichment are flagged, not direct search hits
    if metadata.get("context_enrichment", False):
        header += " [CONTEXT PAGE]"

    return f"{header}\n\n{clear_text(content)}\n"
def _document_to_dict(doc: Document) -> dict:
    """Snapshot a Document's identity, metadata, and text into a plain dict."""
    meta = doc.metadata
    return {
        "doc_id": getattr(doc, 'id', None),  # not all Document versions carry .id
        "source": meta.get("source", "unknown"),
        "provider": meta.get("provider", "unknown"),
        "page_number": meta.get("page_number", "unknown"),
        "disease": meta.get("disease", "unknown"),
        "context_enrichment": meta.get("context_enrichment", False),
        "enriched": meta.get("enriched", False),
        "pages_included": meta.get("pages_included", []),
        "primary_page": meta.get("primary_page"),
        "context_pages_before": meta.get("context_pages_before"),
        "context_pages_after": meta.get("context_pages_after"),
        "content": doc.page_content,
    }
def _format_grouped_by_provider(docs: List[Document]) -> str:
    """Render search results sectioned by provider for multi-provider queries."""
    if not docs:
        return "No results found from any guideline provider."

    # Bucket documents under their provider code
    buckets = {}
    for doc in docs:
        buckets.setdefault(doc.metadata.get("provider", "unknown"), []).append(doc)

    bar = "=" * 70
    parts = [
        f"\n{bar}",
        "SEARCH RESULTS FROM SASLT 2021 GUIDELINES",
        f"Retrieved information from {len(buckets)} guideline provider(s)",
        f"{bar}\n",
    ]

    rule = "─" * 70
    for idx, provider in enumerate(sorted(buckets), start=1):
        group = buckets[provider]
        plural = "s" if len(group) != 1 else ""

        # Section header for this provider
        parts.append(f"\n{rule}")
        parts.append(f"🏥 PROVIDER {idx}: {provider} ({len(group)} result{plural})")
        parts.append(f"{rule}\n")

        # Citations for this provider, numbered locally, blank line between them
        for n, doc in enumerate(group, start=1):
            parts.append(_build_citation(n, doc.metadata or {}, doc.page_content, include_provider=False))
            if n < len(group):
                parts.append("")

    return "\n".join(parts)
def medical_guidelines_knowledge_tool(query: str, provider: Optional[str] = None) -> str:
    """
    Retrieve comprehensive medical guideline knowledge with enriched context from SASLT 2021 guidelines.
    Includes surrounding pages (before/after) for top results to provide complete clinical context.
    This retrieves information from SASLT 2021 guidelines for HBV management.
    Returns detailed text with full metadata and contextual information for expert analysis.
    """
    global _last_question, _last_documents
    try:
        # Remember the query so the validation step can see it later
        _last_question = query

        # Resolve the provider from the explicit arg or the query text;
        # default to SASLT when nothing matches.
        resolved = _normalize_provider(provider, query)
        if not resolved:
            logger.info("No specific provider - querying SASLT")
            resolved = "SASLT"

        # Hybrid (vector + BM25) retrieval against the resolved provider
        documents = hybrid_search(query, resolved, TOOL_K_VECTOR, TOOL_K_BM25)

        # Keep plain-dict snapshots of the retrieved docs for validation
        _last_documents = [_document_to_dict(d) for d in documents]

        return _format_docs_with_citations(documents)
    except Exception as e:
        logger.error(f"Retrieval error: {str(e)}")
        return f"Retrieval error: {str(e)}"
def get_current_datetime_tool() -> str:
    """
    Returns the current date, time, and day of the week for Egypt (Africa/Cairo).
    This is the only reliable source for date and time information. Use this tool
    whenever a user asks about 'today', 'now', or any other time-sensitive query.
    The output is always in English and in standard 12-hour format.
    """
    try:
        # Prefer the stdlib zoneinfo (Python 3.9+) over the third-party pytz;
        # fall back to pytz (already imported by this module) on older
        # interpreters or when the system tz database is unavailable.
        try:
            from zoneinfo import ZoneInfo
            egypt_tz = ZoneInfo("Africa/Cairo")
        except Exception:
            egypt_tz = pytz.timezone('Africa/Cairo')

        now_egypt = datetime.now(egypt_tz)

        # Manual English names so the output is independent of system locale
        days_en = (
            "Monday", "Tuesday", "Wednesday", "Thursday",
            "Friday", "Saturday", "Sunday",
        )
        months_en = (
            "January", "February", "March", "April", "May", "June",
            "July", "August", "September", "October", "November", "December",
        )
        day_name = days_en[now_egypt.weekday()]        # weekday(): Monday == 0
        month_name = months_en[now_egypt.month - 1]    # month: 1-12

        # 12-hour clock: 0 -> 12 AM, 12 -> 12 PM, 13 -> 1 PM, etc.
        hour_12 = now_egypt.hour % 12 or 12
        period = "AM" if now_egypt.hour < 12 else "PM"
        time_str = f"{hour_12:02d}:{now_egypt.minute:02d} {period}"

        return (
            f"Current date and time in Egypt: {day_name}, {month_name} "
            f"{now_egypt.day}, {now_egypt.year} at {time_str}"
        )
    except Exception as e:
        return f"Error getting current datetime: {str(e)}"