# moazx's picture
# Initial commit with all files including LFS
# 73c6377
import os
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Optional, List
import pytz
from langchain.schema import Document
from langchain.tools import tool
from .retrievers import hybrid_search
from .context_enrichment import enrich_retrieved_documents
from .config import logger
# Canonical provider codes. For HBV only SASLT is supported.
CANONICAL_PROVIDERS = ["SASLT"]

# Retrieval / enrichment settings for medical_guidelines_knowledge_tool.
TOOL_K_VECTOR = 5  # vector-search documents to retrieve (per provider)
TOOL_K_BM25 = 2  # BM25 documents to retrieve (per provider)
TOOL_PAGES_BEFORE = 1  # pages of context to include before each top result
TOOL_PAGES_AFTER = 1  # pages of context to include after each top result
TOOL_MAX_ENRICHED = 2  # max top documents to enrich with context (per provider)

# Query and documents of the most recent tool call, kept for later validation.
_last_question = None
_last_documents = None

# Worker count: the CPU count (4 when it cannot be determined), clamped to [2, 8].
TOOL_MAX_WORKERS = min(max(os.cpu_count() or 4, 2), 8)
_tool_executor = ThreadPoolExecutor(max_workers=TOOL_MAX_WORKERS)

# Lowercase aliases / full names -> canonical provider code (insertion order is
# the order _normalize_provider tries them in).
_PROVIDER_ALIASES = dict.fromkeys(
    (
        "saslt",
        "saslt 2021",
        "saudi association for the study of liver diseases and transplantation",
        "saslt guidelines",
    ),
    "SASLT",
)
def _normalize_provider(provider: Optional[str], query: str) -> Optional[str]:
    """Resolve a canonical provider code from an explicit name or the query text.

    A non-empty ``provider`` is examined first; otherwise ``query`` is. The text
    matches when either condition holds:
      * a code in ``CANONICAL_PROVIDERS`` appears in it as a whole word
        (case-insensitive), or
      * a key of ``_PROVIDER_ALIASES`` appears in the lowercased text as a plain
        substring.
    If an explicit ``provider`` matches nothing, ``query`` is tried instead.

    Args:
        provider: Optional provider name supplied by the caller.
        query: The user's query text, used as the fallback source.

    Returns:
        The canonical provider code, or ``None`` when nothing matches.
    """
    haystack = provider or query
    if not haystack:
        return None
    haystack = haystack.lower()

    # Whole-word match against the canonical codes themselves.
    whole_word_hit = next(
        (
            canon
            for canon in CANONICAL_PROVIDERS
            if re.search(rf"\b{re.escape(canon.lower())}\b", haystack)
        ),
        None,
    )
    if whole_word_hit is not None:
        return whole_word_hit

    # Substring match against the known aliases / full names.
    alias_hit = next(
        (canon for alias, canon in _PROVIDER_ALIASES.items() if alias in haystack),
        None,
    )
    if alias_hit is not None:
        return alias_hit

    # The explicit provider did not resolve, so retry using only the query text.
    if provider and provider != query:
        return _normalize_provider(None, query)
    return None
def clear_text(text: str) -> str:
    """Clean and normalize text by removing markdown and excess whitespace.

    Processing steps:
      * Markdown images are dropped.
      * Links are rewritten as ``title (url)``.
      * Heading markers are stripped from line starts, together with any ``>``
        quote markers in front of them.
      * Backticks, ``*`` and ``_`` characters are removed.
      * Trailing spaces/tabs before newlines are removed.
      * Runs of three or more newlines collapse to one blank line.
      * Runs of spaces/tabs collapse to a single space.
      * The result is stripped.

    Args:
        text: Raw, possibly markdown-formatted text. ``None`` or ``""`` yields ``""``.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""
    # Normalize newlines
    t = text.replace("\r\n", "\n").replace("\r", "\n")
    # Images: drop entirely. This must run BEFORE the link rewrite. The link
    # pattern also matches the "[alt](url)" tail of "![alt](url)" and would
    # otherwise turn the image into "!alt (url)".
    t = re.sub(r"!\[[^\]]*\]\([^)]*\)", "", t)
    # Links: keep title and URL
    t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1 (\2)", t)
    # Heading markers, optionally preceded by '>' quote markers, at line starts.
    # Only spaces/tabs are consumed around them. '\s' would also match newlines
    # and swallow the blank line before a heading, merging it into the previous
    # paragraph.
    t = re.sub(r"(?m)^[> \t]*#{1,6}[ \t]*", "", t)
    # Remove backticks/code fences and emphasis markers.
    # NOTE: this removes every '*' and '_', including ones inside identifiers/URLs.
    t = t.replace("```", "").replace("`", "")
    t = t.replace("**", "").replace("*", "").replace("_", "")
    # Drop trailing spaces/tabs before newlines
    t = re.sub(r"[ \t]+\n", "\n", t)
    # Collapse multiple newlines and spaces
    t = re.sub(r"\n{3,}", "\n\n", t)
    t = re.sub(r"[ \t]{2,}", " ", t)
    return t.strip()
def _format_docs_with_citations(docs: List[Document], group_by_provider: bool = False) -> str:
    """Render retrieved documents as numbered citation blocks.

    Args:
        docs: Retrieved documents, in display order.
        group_by_provider: When true, delegate to ``_format_grouped_by_provider``
            for a per-provider layout.

    Returns:
        ``"No results."`` for an empty input. Otherwise, the citations numbered
        from 1 and separated by a blank line.
    """
    if not docs:
        return "No results."
    if group_by_provider:
        return _format_grouped_by_provider(docs)
    # Missing metadata is treated as an empty dict so _build_citation falls back to defaults.
    citations = [
        _build_citation(position, doc.metadata or {}, doc.page_content)
        for position, doc in enumerate(docs, start=1)
    ]
    return "\n\n".join(citations)
def _build_citation(index: int, metadata: dict, content: str, include_provider: bool = True) -> str:
    """Build one citation: a header, a metadata line, then the cleaned content.

    Layout::

        📄 Result {index}:
        Provider: .. | Disease: .. | Source: .. | Page: ..[ [CONTEXT PAGE]]

        {cleaned content}

    Args:
        index: 1-based position shown in the header.
        metadata: Document metadata. Missing provider/disease/source fall back
            to ``"unknown"``; a missing page falls back to ``"?"``.
        content: Raw page text, cleaned with ``clear_text``.
        include_provider: Omit the ``Provider`` field when false (used when
            results are already grouped under a provider heading).

    Returns:
        The formatted citation, ending with a newline.
    """
    labelled_fields = [
        ("Provider", metadata.get("provider", "unknown")),
        ("Disease", metadata.get("disease", "unknown")),
        ("Source", metadata.get("source", "unknown")),
        ("Page", metadata.get("page_number", "?")),
    ]
    if not include_provider:
        labelled_fields = labelled_fields[1:]

    meta_line = " | ".join(f"{label}: {value}" for label, value in labelled_fields)
    # Pages pulled in as surrounding context are tagged so readers can tell them apart.
    if metadata.get("context_enrichment", False):
        meta_line += " [CONTEXT PAGE]"

    body = clear_text(content)
    return f"📄 Result {index}:\n{meta_line}\n\n{body}\n"
def _document_to_dict(doc: Document) -> dict:
    """Flatten a Document into a plain dict (stored as the validation context).

    Output keys, in order:
      * ``doc_id``: the document's ``id`` attribute, or ``None`` if absent.
      * A fixed set of metadata keys, each with its own default.
      * ``content``: the page text.
    """
    record = {"doc_id": getattr(doc, "id", None)}
    metadata = doc.metadata
    # (metadata key, default) pairs in output order. The tuple is rebuilt on
    # every call, so the [] default for pages_included is never shared.
    copied_fields = (
        ("source", "unknown"),
        ("provider", "unknown"),
        ("page_number", "unknown"),
        ("disease", "unknown"),
        ("context_enrichment", False),
        ("enriched", False),
        ("pages_included", []),
        ("primary_page", None),
        ("context_pages_before", None),
        ("context_pages_after", None),
    )
    record.update((key, metadata.get(key, default)) for key, default in copied_fields)
    record["content"] = doc.page_content
    return record
def _format_grouped_by_provider(docs: List[Document]) -> str:
    """Format results grouped by provider for multi-provider queries.

    Grouping and ordering:
      * Documents are bucketed by ``metadata["provider"]`` (``"unknown"`` when
        absent).
      * Buckets are emitted in sorted provider order.
      * Each document keeps its retrieval order inside its bucket.

    Citations are numbered from 1 within each bucket and omit the provider
    field, because the bucket heading already names it.

    Args:
        docs: Retrieved documents; may be empty.

    Returns:
        A banner-framed report, or a fixed message when ``docs`` is empty.
    """
    if not docs:
        return "No results found from any guideline provider."

    # Group documents by provider. The `or {}` guard matches the one used when
    # building citations below; previously only that later use was guarded, so
    # a document without metadata crashed here first.
    provider_groups = {}
    for doc in docs:
        provider = (doc.metadata or {}).get("provider", "unknown")
        provider_groups.setdefault(provider, []).append(doc)

    heavy_rule = "=" * 70
    light_rule = "─" * 70
    parts = [
        f"\n{heavy_rule}",
        "SEARCH RESULTS FROM SASLT 2021 GUIDELINES",
        f"Retrieved information from {len(provider_groups)} guideline provider(s)",
        f"{heavy_rule}\n",
    ]

    # key=str leaves the order of string providers unchanged. It also avoids a
    # TypeError from sorted() when a document carries a non-string provider
    # (e.g. an explicit None in its metadata).
    for idx, provider in enumerate(sorted(provider_groups, key=str), start=1):
        provider_docs = provider_groups[provider]
        count = len(provider_docs)
        parts.append(f"\n{light_rule}")
        parts.append(f"🏥 PROVIDER {idx}: {provider} ({count} result{'s' if count != 1 else ''})")
        parts.append(f"{light_rule}\n")
        for i, doc in enumerate(provider_docs, start=1):
            parts.append(
                _build_citation(i, doc.metadata or {}, doc.page_content, include_provider=False)
            )
            # Blank separator between citations, but not after the last one.
            if i < count:
                parts.append("")
    return "\n".join(parts)
@tool
def medical_guidelines_knowledge_tool(query: str, provider: Optional[str] = None) -> str:
    """
    Retrieve comprehensive medical guideline knowledge with enriched context from SASLT 2021 guidelines.
    Includes surrounding pages (before/after) for top results to provide complete clinical context.
    This retrieves information from SASLT 2021 guidelines for HBV management.
    Returns detailed text with full metadata and contextual information for expert analysis.
    """
    # The docstring above is passed through the @tool decorator as the tool's
    # description, so it is kept verbatim; implementation notes live here.
    #
    # Flow:
    #   1. Resolve the provider, defaulting to "SASLT".
    #   2. Run hybrid_search with TOOL_K_VECTOR / TOOL_K_BM25.
    #   3. Record the query and the retrieved documents in module globals for
    #      later validation.
    #   4. Return the formatted citations.
    # Errors are logged and returned as a "Retrieval error: ..." string instead
    # of being raised.
    #
    # NOTE(review): the docstring promises surrounding-page enrichment, but this
    # body never calls enrich_retrieved_documents and does not read
    # TOOL_PAGES_BEFORE / TOOL_PAGES_AFTER / TOOL_MAX_ENRICHED. Confirm whether
    # hybrid_search enriches internally or the step was dropped.
    global _last_question, _last_documents
    # Record the question and clear the previous call's documents before anything
    # that can fail. Otherwise a failed retrieval would leave this question paired
    # with documents retrieved for a different query.
    _last_question = query
    _last_documents = []
    try:
        # Normalize provider name from either explicit arg or query text
        normalized_provider = _normalize_provider(provider, query)
        if not normalized_provider:
            logger.info("No specific provider - querying SASLT")
            normalized_provider = "SASLT"
        docs = hybrid_search(query, normalized_provider, TOOL_K_VECTOR, TOOL_K_BM25)
        # Store documents for validation context
        _last_documents = [_document_to_dict(doc) for doc in docs]
        return _format_docs_with_citations(docs)
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception(f"Retrieval error: {str(e)}")
        return f"Retrieval error: {str(e)}"
@tool
def get_current_datetime_tool() -> str:
    """
    Returns the current date, time, and day of the week for Egypt (Africa/Cairo).
    This is the only reliable source for date and time information. Use this tool
    whenever a user asks about 'today', 'now', or any other time-sensitive query.
    The output is always in English and in standard 12-hour format.
    """
    # The docstring above is passed through the @tool decorator as the tool's
    # description, so its wording is kept unchanged. Failures are returned as an
    # error string rather than raised.
    #
    # Names are looked up positionally instead of via strftime, so the output is
    # English regardless of the process locale.
    weekday_names = (
        "Monday", "Tuesday", "Wednesday", "Thursday",
        "Friday", "Saturday", "Sunday",
    )
    month_names = (
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December",
    )
    try:
        cairo_now = datetime.now(pytz.timezone('Africa/Cairo'))
        # 24h -> 12h: 0 -> 12 AM, 1..11 -> AM, 12 -> 12 PM, 13..23 -> 1..11 PM
        clock_hour = cairo_now.hour % 12 or 12
        meridiem = "AM" if cairo_now.hour < 12 else "PM"
        time_str = f"{clock_hour:02d}:{cairo_now.minute:02d} {meridiem}"
        day_name = weekday_names[cairo_now.weekday()]
        month_name = month_names[cairo_now.month - 1]
        day = cairo_now.day
        year = cairo_now.year
        return f"Current date and time in Egypt: {day_name}, {month_name} {day}, {year} at {time_str}"
    except Exception as e:
        return f"Error getting current datetime: {str(e)}"