"""LangChain tools: SASLT guideline retrieval and current Egypt date/time."""

import os
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Optional, List

import pytz
from langchain.schema import Document
from langchain.tools import tool

from .retrievers import hybrid_search
from .context_enrichment import enrich_retrieved_documents
from .config import logger

# Canonical provider names - For HBV: SASLT only
CANONICAL_PROVIDERS = ["SASLT"]

# Global configuration for medical_guidelines_knowledge_tool retrieval and enrichment
TOOL_K_VECTOR = 5  # Number of documents to retrieve using vector search (per provider)
TOOL_K_BM25 = 2  # Number of documents to retrieve using BM25 search (per provider)
TOOL_PAGES_BEFORE = 1  # Number of pages to include before each top result
TOOL_PAGES_AFTER = 1  # Number of pages to include after each top result
TOOL_MAX_ENRICHED = 2  # Maximum number of top documents to enrich with context (per provider)

# Global variables to store context for validation
_last_question = None  # Stores the tool query
_last_documents = None  # Stores the retrieved documents as plain dicts

TOOL_MAX_WORKERS = max(2, min(8, (os.cpu_count() or 4)))
_tool_executor = ThreadPoolExecutor(max_workers=TOOL_MAX_WORKERS)

# Map lowercase variants and full names to canonical provider codes
_PROVIDER_ALIASES = {
    "saslt": "SASLT",
    "saslt 2021": "SASLT",
    "saudi association for the study of liver diseases and transplantation": "SASLT",
    "saslt guidelines": "SASLT",
}


def _normalize_provider(provider: Optional[str], query: str) -> Optional[str]:
    """Normalize provider name from explicit parameter or query text.

    Args:
        provider: Provider name supplied by the caller, if any.
        query: The user query; scanned when ``provider`` is empty or does
            not match any known provider.

    Returns:
        The canonical provider code, or ``None`` when nothing matches.
    """
    text = provider if provider else query
    if not text:
        return None

    lowered = text.lower()

    # Quick direct hits for canonical providers (whole-word match)
    for canon in CANONICAL_PROVIDERS:
        if re.search(rf"\b{re.escape(canon.lower())}\b", lowered):
            return canon

    # Alias-based detection (substring match)
    for alias, canon in _PROVIDER_ALIASES.items():
        if alias in lowered:
            return canon

    # If explicit provider didn't match, try query text as fallback
    if provider and provider != query:
        return _normalize_provider(None, query)

    return None


def clear_text(text: str) -> str:
    """Clean and normalize text by removing markdown and excess whitespace.

    Args:
        text: Raw (possibly markdown) text; may be empty or ``None``.

    Returns:
        The cleaned text, or ``""`` for falsy input.
    """
    if not text:
        return ""

    # Normalize newlines
    t = text.replace("\r\n", "\n").replace("\r", "\n")

    # Images must be stripped BEFORE links are rewritten: the link pattern
    # also matches the "[alt](url)" tail of "![alt](url)" and would otherwise
    # leave "!alt (url)" behind instead of dropping the image.
    # Linked images ("[![alt](img)](link)"): drop entirely
    t = re.sub(r"\[!\[[^\]]*\]\([^)]*\)\]\([^)]*\)", "", t)
    # Images: drop entirely
    t = re.sub(r"!\[[^\]]*\]\([^)]*\)", "", t)
    # Links: keep title and URL
    t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1 (\2)", t)

    # Remove headers/quotes markers at line starts
    t = re.sub(r"(?m)^[>\s]*#{1,6}\s*", "", t)

    # Remove backticks/code fences and emphasis
    t = t.replace("```", "").replace("`", "")
    t = t.replace("**", "").replace("*", "").replace("_", "")

    # Collapse spaces before newlines
    t = re.sub(r"[ \t]+\n", "\n", t)

    # Collapse multiple newlines and spaces
    t = re.sub(r"\n{3,}", "\n\n", t)
    t = re.sub(r"[ \t]{2,}", " ", t)

    # Trim (no truncation is applied)
    return t.strip()


def _format_docs_with_citations(docs: List[Document], group_by_provider: bool = False) -> str:
    """Format documents with citations.

    Args:
        docs: Documents to render.
        group_by_provider: When true, delegate to the provider-grouped layout.

    Returns:
        The formatted text, or ``"No results."`` when ``docs`` is empty.
    """
    if not docs:
        return "No results."

    if group_by_provider:
        return _format_grouped_by_provider(docs)

    parts = [
        _build_citation(i, d.metadata or {}, d.page_content)
        for i, d in enumerate(docs, start=1)
    ]
    return "\n\n".join(parts)


def _build_citation(index: int, metadata: dict, content: str, include_provider: bool = True) -> str:
    """Build a single citation string with clean formatting.

    Args:
        index: 1-based result number shown in the header.
        metadata: Document metadata; missing keys fall back to placeholders.
        content: Raw page content, cleaned with ``clear_text``.
        include_provider: Whether to include the provider in the metadata line.

    Returns:
        A header line, a ``" | "``-separated metadata line and the snippet.
    """
    source = metadata.get("source", "unknown")
    page = metadata.get("page_number", "?")
    provider = metadata.get("provider", "unknown")
    disease = metadata.get("disease", "unknown")
    is_context = metadata.get("context_enrichment", False)

    snippet = clear_text(content)

    # Build citation header
    citation = f"📄 Result {index}:\n"

    # Build metadata line
    metadata_parts = []
    if include_provider:
        metadata_parts.append(f"Provider: {provider}")
    metadata_parts.append(f"Disease: {disease}")
    metadata_parts.append(f"Source: {source}")
    metadata_parts.append(f"Page: {page}")
    citation += " | ".join(metadata_parts)

    if is_context:
        citation += " [CONTEXT PAGE]"

    citation += f"\n\n{snippet}\n"
    return citation


def _document_to_dict(doc: Document) -> dict:
    """Convert a Document to a dictionary for storage.

    Missing metadata (including ``metadata`` being ``None``) falls back to
    the same defaults used for absent keys.
    """
    # Guard against metadata=None, consistent with the formatting helpers.
    meta = doc.metadata or {}
    return {
        "doc_id": getattr(doc, "id", None),
        "source": meta.get("source", "unknown"),
        "provider": meta.get("provider", "unknown"),
        "page_number": meta.get("page_number", "unknown"),
        "disease": meta.get("disease", "unknown"),
        "context_enrichment": meta.get("context_enrichment", False),
        "enriched": meta.get("enriched", False),
        "pages_included": meta.get("pages_included", []),
        "primary_page": meta.get("primary_page"),
        "context_pages_before": meta.get("context_pages_before"),
        "context_pages_after": meta.get("context_pages_after"),
        "content": doc.page_content,
    }


def _format_grouped_by_provider(docs: List[Document]) -> str:
    """Format results grouped by provider for multi-provider queries.

    Args:
        docs: Documents to render; grouped by their ``provider`` metadata.

    Returns:
        A banner followed by one section per provider (sorted by name), each
        listing that provider's citations numbered from 1.
    """
    if not docs:
        return "No results found from any guideline provider."

    # Group documents by provider
    provider_groups = {}
    for doc in docs:
        provider = (doc.metadata or {}).get("provider", "unknown")
        provider_groups.setdefault(provider, []).append(doc)

    # Format header
    parts = [
        f"\n{'='*70}",
        "SEARCH RESULTS FROM SASLT 2021 GUIDELINES",
        f"Retrieved information from {len(provider_groups)} guideline provider(s)",
        f"{'='*70}\n",
    ]

    # Format each provider's results; key=str keeps ordering for string names
    # and avoids a TypeError if a provider value is not a string.
    for idx, provider in enumerate(sorted(provider_groups, key=str), start=1):
        provider_docs = provider_groups[provider]
        count = len(provider_docs)

        # Provider header
        parts.append(f"\n{'─'*70}")
        parts.append(f"🏥 PROVIDER {idx}: {provider} ({count} result{'s' if count != 1 else ''})")
        parts.append(f"{'─'*70}\n")

        # Format each document for this provider
        for i, doc in enumerate(provider_docs, start=1):
            meta = doc.metadata or {}
            parts.append(_build_citation(i, meta, doc.page_content, include_provider=False))
            # Blank separator between citations, but not after the last one
            if i < count:
                parts.append("")

    return "\n".join(parts)


# NOTE: the docstring below is consumed by the @tool decorator as the tool
# description, so its wording is runtime-visible and is kept unchanged.
# NOTE(review): the docstring promises enrichment with surrounding pages, but
# enrich_retrieved_documents and TOOL_PAGES_BEFORE / TOOL_PAGES_AFTER /
# TOOL_MAX_ENRICHED are not used here - confirm whether enrichment happens
# inside hybrid_search or is missing.
@tool
def medical_guidelines_knowledge_tool(query: str, provider: Optional[str] = None) -> str:
    """
    Retrieve comprehensive medical guideline knowledge with enriched context from SASLT 2021 guidelines.
    Includes surrounding pages (before/after) for top results to provide complete clinical context.

    This retrieves information from SASLT 2021 guidelines for HBV management.

    Returns detailed text with full metadata and contextual information for expert analysis.
    """
    global _last_question, _last_documents

    try:
        # Store question for validation context
        _last_question = query

        # Normalize provider name from either explicit arg or query text
        normalized_provider = _normalize_provider(provider, query)

        # Query SASLT provider
        if not normalized_provider:
            logger.info("No specific provider - querying SASLT")
            normalized_provider = "SASLT"

        # Perform hybrid search
        docs = hybrid_search(query, normalized_provider, TOOL_K_VECTOR, TOOL_K_BM25)

        # Store documents for validation context
        # NOTE(review): on failure _last_documents keeps the previous query's
        # documents while _last_question is already updated - confirm intended.
        _last_documents = [_document_to_dict(doc) for doc in docs]

        return _format_docs_with_citations(docs)

    except Exception as e:
        # Top-level tool boundary: report the failure as text instead of raising.
        logger.error(f"Retrieval error: {str(e)}")
        return f"Retrieval error: {str(e)}"


# NOTE: docstring is the runtime tool description (see @tool); wording unchanged.
@tool
def get_current_datetime_tool() -> str:
    """
    Returns the current date, time, and day of the week for Egypt (Africa/Cairo).
    This is the only reliable source for date and time information. Use this tool
    whenever a user asks about 'today', 'now', or any other time-sensitive query.
    The output is always in English and in standard 12-hour format.
    """
    try:
        # Define the timezone for Egypt
        egypt_tz = pytz.timezone('Africa/Cairo')

        # Get the current time in that timezone
        now_egypt = datetime.now(egypt_tz)

        # Manual tables to ensure English output regardless of system locale.
        # datetime.weekday() is 0 for Monday; month is 1-based.
        days_en = (
            "Monday", "Tuesday", "Wednesday", "Thursday",
            "Friday", "Saturday", "Sunday",
        )
        months_en = (
            "January", "February", "March", "April", "May", "June",
            "July", "August", "September", "October", "November", "December",
        )

        day_name = days_en[now_egypt.weekday()]
        month_name = months_en[now_egypt.month - 1]
        day = now_egypt.day
        year = now_egypt.year

        # Format time manually to avoid locale issues
        hour = now_egypt.hour
        minute = now_egypt.minute

        # Convert to 12-hour format: 0 -> 12 AM, 12 -> 12 PM, 13..23 -> 1..11 PM
        hour_12 = hour % 12 or 12
        period = "AM" if hour < 12 else "PM"

        time_str = f"{hour_12:02d}:{minute:02d} {period}"

        # Create the final string
        return f"Current date and time in Egypt: {day_name}, {month_name} {day}, {year} at {time_str}"

    except Exception as e:
        return f"Error getting current datetime: {str(e)}"