File size: 10,372 Bytes
73c6377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import os
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Optional, List

import pytz
from langchain.schema import Document
from langchain.tools import tool
from .retrievers import hybrid_search
from .context_enrichment import enrich_retrieved_documents
from .config import logger

# Canonical provider names - For HBV: SASLT only
CANONICAL_PROVIDERS = ["SASLT"]

# Retrieval / enrichment knobs for medical_guidelines_knowledge_tool.
# Per-provider document counts for each retriever:
TOOL_K_VECTOR = 5      # hits taken from vector search
TOOL_K_BM25 = 2        # hits taken from BM25 keyword search
# Context window around the top results:
TOOL_PAGES_BEFORE = 1  # pages to include before each top result
TOOL_PAGES_AFTER = 1   # pages to include after each top result
TOOL_MAX_ENRICHED = 2  # cap on top documents enriched with context (per provider)

# Module-level record of the most recent tool call, kept as context for a
# later validation step.
_last_question = None   # the query string passed to the tool
_last_documents = None  # the retrieved documents, serialized to dicts

# Worker count clamped into [2, 8]; assumes 4 CPUs when os.cpu_count() is unknown.
TOOL_MAX_WORKERS = min(8, max(2, os.cpu_count() or 4))
_tool_executor = ThreadPoolExecutor(max_workers=TOOL_MAX_WORKERS)


# Lower-cased spellings and full names that all resolve to the canonical code.
_PROVIDER_ALIASES = {
    alias: "SASLT"
    for alias in (
        "saslt",
        "saslt 2021",
        "saudi association for the study of liver diseases and transplantation",
        "saslt guidelines",
    )
}


def _normalize_provider(provider: Optional[str], query: str) -> Optional[str]:
    """Resolve a canonical provider code from *provider*, falling back to *query*.

    A non-empty ``provider`` argument is scanned first. If it names no known
    provider and differs from ``query``, the query text is scanned instead.
    Canonical codes must appear as whole words; entries of
    ``_PROVIDER_ALIASES`` are matched as plain substrings.

    Returns:
        The canonical provider code, or ``None`` if nothing matched.
    """

    def _scan(text: Optional[str]) -> Optional[str]:
        # Look for a provider mention inside a single piece of text.
        if not text:
            return None
        lowered = text.lower()
        for code in CANONICAL_PROVIDERS:
            pattern = r"\b" + re.escape(code.lower()) + r"\b"
            if re.search(pattern, lowered):
                return code
        return next(
            (code for alias, code in _PROVIDER_ALIASES.items() if alias in lowered),
            None,
        )

    if not provider:
        return _scan(query)

    found = _scan(provider)
    if found is None and provider != query:
        # The explicit provider was unrecognised; give the query a chance.
        found = _scan(query)
    return found


def clear_text(text: str) -> str:
    """Clean and normalize text by removing markdown and excess whitespace.

    Steps, in order:

    1. Normalize line endings to ``\\n``.
    2. Drop images.
    3. Rewrite links as ``title (url)``.
    4. Strip heading markers, including ones preceded by ``>``.
    5. Remove backticks and ``*`` emphasis.
    6. Remove ``_`` only where it delimits emphasis, so URLs and snake_case
       words survive.
    7. Collapse redundant whitespace and trim.

    Args:
        text: Raw, markdown-ish text. ``None`` or ``""`` yields ``""``.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""
    t = text
    # Normalize newlines
    t = t.replace("\r\n", "\n").replace("\r", "\n")
    # Images: drop entirely. This must run before the link rule; otherwise
    # "![alt](url)" is first rewritten to "!alt (url)" and is never dropped.
    t = re.sub(r"!\[[^\]]*\]\([^)]*\)", "", t)
    # Links: keep title and URL
    t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1 (\2)", t)
    # Remove heading markers (optionally after ">" quote markers) at line
    # starts. Only spaces/tabs are consumed around them. "\s" would also
    # swallow newlines and delete the blank line before a heading.
    t = re.sub(r"(?m)^[> \t]*#{1,6}[ \t]*", "", t)
    # Remove backticks/code fences and asterisk emphasis
    t = t.replace("```", "").replace("`", "")
    t = t.replace("**", "").replace("*", "")
    # Remove underscore emphasis delimiters (_x_, __x__), but keep
    # underscores between word characters, e.g. "snake_case" or ".../a_b" URLs.
    t = re.sub(r"(?<!\w)_{1,2}|_{1,2}(?!\w)", "", t)
    # Collapse spaces before newlines
    t = re.sub(r"[ \t]+\n", "\n", t)
    # Collapse multiple newlines and spaces
    t = re.sub(r"\n{3,}", "\n\n", t)
    t = re.sub(r"[ \t]{2,}", " ", t)
    # Trim surrounding whitespace
    return t.strip()


def _format_docs_with_citations(docs: List[Document], group_by_provider: bool = False) -> str:
    """Render retrieved documents as numbered citation blocks.

    Args:
        docs: Documents to render, in display order.
        group_by_provider: If True, delegate to ``_format_grouped_by_provider``,
            which sections the results under per-provider headings.

    Returns:
        ``"No results."`` for an empty list. Otherwise, one citation per
        document (numbered from 1), with citations separated by blank lines.
    """
    if not docs:
        return "No results."
    if group_by_provider:
        return _format_grouped_by_provider(docs)
    return "\n\n".join(
        _build_citation(number, doc.metadata or {}, doc.page_content)
        for number, doc in enumerate(docs, start=1)
    )


def _build_citation(index: int, metadata: dict, content: str, include_provider: bool = True) -> str:
    """Build a single citation string with clean formatting.

    Layout: a page-icon header ``Result {index}:``; then one line of
    ``" | "``-joined metadata fields (optionally tagged ``[CONTEXT PAGE]``);
    then a blank line, the cleaned content, and a trailing newline.

    Args:
        index: Position shown in the "Result N" header.
        metadata: Document metadata. Missing keys fall back to ``"unknown"``,
            or ``"?"`` for the page number. ``None`` is treated as empty.
        content: Raw page text, cleaned with ``clear_text``.
        include_provider: When False, the "Provider:" field is omitted. The
            grouped formatter does this because it already prints the provider
            in a section heading.

    Returns:
        The formatted citation string.
    """
    meta = metadata or {}
    source = meta.get("source", "unknown")
    page = meta.get("page_number", "?")
    provider = meta.get("provider", "unknown")
    disease = meta.get("disease", "unknown")
    is_context = meta.get("context_enrichment", False)

    snippet = clear_text(content)

    # The header icon used to be mojibake ("πŸ“„", which is UTF-8 U+1F4C4
    # mis-decoded). Spell it by Unicode name so the source encoding cannot
    # corrupt it again.
    icon = "\N{PAGE FACING UP}"
    citation = f"{icon} Result {index}:\n"

    metadata_parts = []
    if include_provider:
        metadata_parts.append(f"Provider: {provider}")
    metadata_parts.append(f"Disease: {disease}")
    metadata_parts.append(f"Source: {source}")
    metadata_parts.append(f"Page: {page}")

    citation += " | ".join(metadata_parts)

    # Pages pulled in as surrounding context are flagged so readers can tell
    # them apart from direct hits.
    if is_context:
        citation += " [CONTEXT PAGE]"

    citation += f"\n\n{snippet}\n"
    return citation


def _document_to_dict(doc: Document) -> dict:
    """Convert a Document to a plain dictionary for storage.

    Copies the document's ``id`` attribute (``None`` if absent), its page text,
    and a fixed set of metadata fields, substituting defaults for missing keys.
    A ``None`` metadata mapping is treated as empty instead of raising
    ``AttributeError``.
    """
    meta = doc.metadata or {}
    return {
        "doc_id": getattr(doc, "id", None),
        "source": meta.get("source", "unknown"),
        "provider": meta.get("provider", "unknown"),
        "page_number": meta.get("page_number", "unknown"),
        "disease": meta.get("disease", "unknown"),
        "context_enrichment": meta.get("context_enrichment", False),
        "enriched": meta.get("enriched", False),
        "pages_included": meta.get("pages_included", []),
        "primary_page": meta.get("primary_page"),
        "context_pages_before": meta.get("context_pages_before"),
        "context_pages_after": meta.get("context_pages_after"),
        "content": doc.page_content,
    }


def _format_grouped_by_provider(docs: List[Document]) -> str:
    """Format results grouped by provider for multi-provider queries.

    Output structure:

    - A banner stating how many distinct providers were found.
    - One section per provider, in sorted provider order. Each section lists
      that provider's citations, numbered from 1 and in retrieval order, with
      the provider field omitted from each citation.

    Args:
        docs: Retrieved documents. A missing or ``None`` metadata mapping is
            grouped under ``"unknown"`` instead of raising.

    Returns:
        The formatted text, or a fixed "No results" message for an empty list.
    """
    if not docs:
        return "No results found from any guideline provider."

    # Group documents by provider, preserving retrieval order within each group.
    provider_groups = {}
    for doc in docs:
        meta = doc.metadata or {}
        provider_groups.setdefault(meta.get("provider", "unknown"), []).append(doc)

    heavy_rule = "=" * 70
    light_rule = "─" * 70
    # This icon used to be mojibake ("πŸ₯", which is UTF-8 U+1F3E5
    # mis-decoded). Spell it by Unicode name so the source encoding cannot
    # corrupt it.
    hospital = "\N{HOSPITAL}"

    parts = [
        f"\n{heavy_rule}",
        "SEARCH RESULTS FROM SASLT 2021 GUIDELINES",
        f"Retrieved information from {len(provider_groups)} guideline provider(s)",
        f"{heavy_rule}\n",
    ]

    # key=str keeps sorting from raising if a provider value is not a string;
    # for all-string keys the order is unchanged.
    for idx, provider in enumerate(sorted(provider_groups, key=str), start=1):
        provider_docs = provider_groups[provider]
        count = len(provider_docs)
        plural = "s" if count != 1 else ""

        parts.append(f"\n{light_rule}")
        parts.append(f"{hospital} PROVIDER {idx}: {provider} ({count} result{plural})")
        parts.append(f"{light_rule}\n")

        for i, doc in enumerate(provider_docs, start=1):
            meta = doc.metadata or {}
            parts.append(_build_citation(i, meta, doc.page_content, include_provider=False))
            # Empty entry -> extra blank line between citations (not after the last).
            if i < count:
                parts.append("")

    return "\n".join(parts)


@tool
def medical_guidelines_knowledge_tool(query: str, provider: Optional[str] = None) -> str:
    """
    Retrieve comprehensive medical guideline knowledge with enriched context from SASLT 2021 guidelines.
    Includes surrounding pages (before/after) for top results to provide complete clinical context.
    
    This retrieves information from SASLT 2021 guidelines for HBV management.
    
    Returns detailed text with full metadata and contextual information for expert analysis.
    """
    # NOTE: the docstring above is presumably what @tool exposes as the tool
    # description to the model, so it is kept verbatim; edit it deliberately.
    global _last_question, _last_documents
    try:
        # Record the query for validation context. Also clear documents from
        # any previous call, so a failed retrieval never leaves stale
        # documents paired with this new query.
        _last_question = query
        _last_documents = None

        # Normalize provider name from either explicit arg or query text;
        # default to SASLT, the only configured provider.
        normalized_provider = _normalize_provider(provider, query)
        if not normalized_provider:
            logger.info("No specific provider - querying SASLT")
            normalized_provider = "SASLT"

        # NOTE(review): TOOL_PAGES_BEFORE/TOOL_PAGES_AFTER/TOOL_MAX_ENRICHED and
        # the imported enrich_retrieved_documents are not used here, even though
        # the docstring advertises surrounding-page context. Confirm whether
        # enrichment should run on these results.
        docs = hybrid_search(query, normalized_provider, TOOL_K_VECTOR, TOOL_K_BM25) or []

        # Store documents for validation context
        _last_documents = [_document_to_dict(doc) for doc in docs]

        return _format_docs_with_citations(docs)
    except Exception as e:
        # Tool boundary: report the failure to the caller as text instead of
        # raising, and log it with its traceback.
        logger.exception(f"Retrieval error: {e}")
        return f"Retrieval error: {str(e)}"


@tool
def get_current_datetime_tool() -> str:
    """
    Returns the current date, time, and day of the week for Egypt (Africa/Cairo).
    This is the only reliable source for date and time information. Use this tool
    whenever a user asks about 'today', 'now', or any other time-sensitive query.
    The output is always in English and in standard 12-hour format.
    """
    # The docstring above doubles as the tool description, so it stays verbatim.
    # Day and month names come from fixed English tables, indexed by weekday()
    # (Monday == 0) and month - 1, rather than from strftime. This keeps the
    # output independent of the system locale.
    weekday_names = (
        "Monday", "Tuesday", "Wednesday", "Thursday",
        "Friday", "Saturday", "Sunday",
    )
    month_names = (
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December",
    )
    try:
        moment = datetime.now(pytz.timezone('Africa/Cairo'))

        # 12-hour clock: 0 -> 12 AM, 1-11 AM, 12 -> 12 PM, 13-23 -> 1-11 PM.
        hour_12 = moment.hour % 12 or 12
        meridiem = "AM" if moment.hour < 12 else "PM"
        clock = f"{hour_12:02d}:{moment.minute:02d} {meridiem}"

        weekday = weekday_names[moment.weekday()]
        month = month_names[moment.month - 1]
        date_part = f"{weekday}, {month} {moment.day}, {moment.year}"

        return f"Current date and time in Egypt: {date_part} at {clock}"
    except Exception as e:
        return f"Error getting current datetime: {str(e)}"