import tempfile
import os
from models import GraphState
from datetime import datetime
from gradio_client import Client, file
import logging
import dotenv
import httpx
import json
from typing import AsyncGenerator, Optional
from utils import detect_file_type, convert_context_to_list, merge_state, getconfig
from retriever_adapter import RetrieverAdapter
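# Load environment variables (e.g. HF_TOKEN) from a local .env file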
dotenv.load_dotenv()
logger = logging.getLogger(__name__)
# Load config
config = getconfig("params.cfg")
RETRIEVER = config.get("retriever", "RETRIEVER")
GENERATOR = config.get("generator", "GENERATOR")
INGESTOR = config.get("ingestor", "INGESTOR")
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
# Check if direct output mode is enabled
DIRECT_OUTPUT_ENABLED = config.getboolean("file_processing", "DIRECT_OUTPUT", fallback=False)
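# Shared retriever adapter, configured from params.cfg (used by retrieve_node)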
retriever_adapter = RetrieverAdapter("params.cfg")
#----------------------------------------
# LANGGRAPH NODE FUNCTIONS
#----------------------------------------
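# Each node takes the full GraphState and returns a partial state update;
# process_query_streaming merges these updates back into state via merge_state()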
def detect_file_type_node(state: GraphState) -> GraphState:
"""Detect file type and determine workflow"""
file_type = "unknown"
workflow_type = "standard"
if state.get("file_content") and state.get("filename"):
file_type = detect_file_type(state["filename"], state["file_content"])
    # Choose the workflow based on the DIRECT_OUTPUT config flag
    if DIRECT_OUTPUT_ENABLED:
        logger.info("Direct output mode enabled - ingestor results will be returned directly")
        workflow_type = "direct_output"
    else:
        logger.info("Direct output mode disabled - using standard RAG pipeline")
        workflow_type = "standard"
metadata = state.get("metadata", {})
metadata.update({
"file_type": file_type,
"workflow_type": workflow_type,
"direct_output_enabled": DIRECT_OUTPUT_ENABLED
})
return {
"file_type": file_type,
"workflow_type": workflow_type,
"metadata": metadata
}
def ingest_node(state: GraphState) -> GraphState:
"""Process file through appropriate ingestor based on file type"""
start_time = datetime.now()
if not state.get("file_content") or not state.get("filename"):
logger.info("No file provided, skipping ingestion")
return {"ingestor_context": "", "metadata": state.get("metadata", {})}
file_type = state.get("file_type", "unknown")
logger.info(f"Ingesting {file_type} file: {state['filename']}")
try:
ingestor_url = INGESTOR
logger.info(f"Using ingestor: {ingestor_url}")
client = Client(ingestor_url, hf_token=os.getenv("HF_TOKEN"))
# Create temporary file for upload
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
tmp_file.write(state["file_content"])
tmp_file_path = tmp_file.name
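        # Call the ingestor's /ingest endpoint; the temp file is removed in
        # the finally block whether or not the call succeeds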
try:
ingestor_context = client.predict(file(tmp_file_path), api_name="/ingest")
logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
raise Exception(ingestor_context)
finally:
os.unlink(tmp_file_path)
duration = (datetime.now() - start_time).total_seconds()
metadata = state.get("metadata", {})
metadata.update({
"ingestion_duration": duration,
"ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
"ingestion_success": True,
"ingestor_used": ingestor_url
})
return {"ingestor_context": ingestor_context, "metadata": metadata}
except Exception as e:
duration = (datetime.now() - start_time).total_seconds()
logger.error(f"Ingestion failed: {str(e)}")
metadata = state.get("metadata", {})
metadata.update({
"ingestion_duration": duration,
"ingestion_success": False,
"ingestion_error": str(e)
})
return {"ingestor_context": "", "metadata": metadata}
def direct_output_node(state: GraphState) -> GraphState:
"""
    When direct output mode is enabled, return the ingestor results directly
    instead of running retrieval and generation.
"""
file_type = state.get('file_type', 'unknown')
logger.info(f"Direct output mode - returning ingestor results for {file_type} file")
ingestor_context = state.get("ingestor_context", "")
result = ingestor_context if ingestor_context else "No results from file processing."
metadata = state.get("metadata", {})
metadata.update({
"processing_type": "direct_output",
"result_length": len(result)
})
return {"result": result, "metadata": metadata}
def retrieve_node(state: GraphState) -> GraphState:
"""Retrieve relevant context using adapter"""
start_time = datetime.now()
logger.info(f"Retrieval: {state['query'][:50]}...")
try:
# Get filters from state (provided by ChatUI or LLM agent)
filters = state.get("metadata_filters")
context = retriever_adapter.retrieve(
query=state["query"],
filters=filters,
hf_token=os.getenv("HF_TOKEN")
)
duration = (datetime.now() - start_time).total_seconds()
metadata = state.get("metadata", {})
metadata.update({
"retrieval_duration": duration,
"context_length": len(context) if context else 0,
"retrieval_success": True,
"filters_applied": filters,
"retriever_config": retriever_adapter.get_metadata()
})
return {"context": context, "metadata": metadata}
except Exception as e:
duration = (datetime.now() - start_time).total_seconds()
logger.error(f"Retrieval failed: {str(e)}")
metadata = state.get("metadata", {})
metadata.update({
"retrieval_duration": duration,
"retrieval_success": False,
"retrieval_error": str(e)
})
return {"context": "", "metadata": metadata}
async def generate_node_streaming(state: GraphState) -> AsyncGenerator[GraphState, None]:
"""Streaming generation using generator's FastAPI endpoint"""
start_time = datetime.now()
logger.info(f"Generation (streaming): {state['query'][:50]}...")
try:
# Combine contexts
retrieved_context = state.get("context", "")
ingestor_context = state.get("ingestor_context", "")
logger.info(f"Context lengths - Ingestor: {len(ingestor_context)}, Retrieved: {len(retrieved_context)}")
# Build context list with truncation
context_list = []
total_context_chars = 0
if ingestor_context:
truncated_ingestor = (
ingestor_context[:MAX_CONTEXT_CHARS] + "...\n[Content truncated due to length]"
if len(ingestor_context) > MAX_CONTEXT_CHARS
else ingestor_context
)
context_list.append({
"answer": truncated_ingestor,
"answer_metadata": {
"filename": state.get("filename", "Uploaded Document"),
"page": "Unknown",
"year": "Unknown",
"source": "Ingestor"
}
})
total_context_chars += len(truncated_ingestor)
if retrieved_context and total_context_chars < MAX_CONTEXT_CHARS:
retrieved_list = convert_context_to_list(retrieved_context)
remaining_chars = MAX_CONTEXT_CHARS - total_context_chars
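            # Append retrieved items until the character budget is exhausted;
            # truncate the last item only if more than 100 chars remain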
for item in retrieved_list:
item_text = item.get("answer", "")
if len(item_text) <= remaining_chars:
context_list.append(item)
remaining_chars -= len(item_text)
else:
if remaining_chars > 100:
item["answer"] = item_text[:remaining_chars-50] + "...\n[Content truncated]"
context_list.append(item)
break
final_context_size = sum(len(item.get("answer", "")) for item in context_list)
logger.info(f"Final context size: {final_context_size} characters (limit: {MAX_CONTEXT_CHARS})")
payload = {"query": state["query"], "context": context_list}
        # Generator endpoint from config
        generator_url = GENERATOR
# Stream from generator with authentication
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"
}
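        # NOTE: TLS verification is disabled (verify=False); acceptable only
        # for trusted internal endpoints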
async with httpx.AsyncClient(timeout=300.0, verify=False) as client:
async with client.stream(
"POST",
f"{generator_url}/generate/stream",
json=payload,
headers=headers
) as response:
if response.status_code != 200:
raise Exception(f"Generator returned status {response.status_code}")
current_text = ""
sources = None
event_type = None
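                # Parse the SSE stream: "event:" lines set the current event
                # type and "data:" lines carry that event's payload, e.g.
                #   event: data
                #   data: "chunk text"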
async for line in response.aiter_lines():
if not line.strip():
continue
if line.startswith("event: "):
event_type = line[7:].strip()
continue
elif line.startswith("data: "):
data_content = line[6:].strip()
if event_type == "data":
try:
chunk = json.loads(data_content)
if isinstance(chunk, str):
current_text += chunk
except json.JSONDecodeError:
current_text += data_content
chunk = data_content
metadata = state.get("metadata", {})
metadata.update({
"generation_duration": (datetime.now() - start_time).total_seconds(),
"result_length": len(current_text),
"generation_success": True,
"streaming": True,
"context_chars_used": final_context_size
})
yield {"result": chunk, "metadata": metadata}
elif event_type == "sources":
try:
sources_data = json.loads(data_content)
sources = sources_data.get("sources", [])
metadata = state.get("metadata", {})
metadata.update({
"sources_received": True,
"sources_count": len(sources)
})
yield {"sources": sources, "metadata": metadata}
except json.JSONDecodeError:
logger.warning(f"Failed to parse sources: {data_content}")
elif event_type == "end":
logger.info("Generator stream ended")
break
elif event_type == "error":
try:
error_data = json.loads(data_content)
raise Exception(error_data.get("error", "Unknown error"))
except json.JSONDecodeError:
raise Exception(data_content)
except Exception as e:
duration = (datetime.now() - start_time).total_seconds()
logger.error(f"Streaming generation failed: {str(e)}")
metadata = state.get("metadata", {})
metadata.update({
"generation_duration": duration,
"generation_success": False,
"generation_error": str(e),
"streaming": True
})
yield {"result": f"Error: {str(e)}", "metadata": metadata}
def route_workflow(state: GraphState) -> str:
"""
Conditional routing based on workflow type after ingestion.
Returns 'direct_output' when DIRECT_OUTPUT=True, 'standard' otherwise.
"""
workflow_type = state.get("workflow_type", "standard")
logger.info(f"Routing to: {workflow_type}")
return workflow_type
#----------------------------------------
# UNIFIED STREAMING PROCESSOR
#----------------------------------------
async def process_query_streaming(
query: str,
file_upload=None,
file_content: Optional[bytes] = None,
filename: Optional[str] = None,
reports_filter: str = "",
sources_filter: str = "",
subtype_filter: str = "",
year_filter: str = "",
output_format: str = "structured",
    conversation_context: Optional[str] = None  # optional conversation history for generation
):
"""
Unified streaming function with conversation context support.
Args:
query: Latest user query
conversation_context: Optional conversation history for generation context
... (other args remain the same)
"""
# Handle file_upload if provided
if file_upload is not None:
try:
with open(file_upload.name, 'rb') as f:
file_content = f.read()
filename = os.path.basename(file_upload.name)
logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
except Exception as e:
logger.error(f"Error reading uploaded file: {str(e)}")
if output_format == "structured":
yield {"type": "error", "content": f"Error reading file: {str(e)}"}
else:
yield f"Error reading file: {str(e)}"
return
start_time = datetime.now()
session_id = f"stream_{start_time.strftime('%Y%m%d_%H%M%S')}"
# Log retrieval strategy
logger.info(f"Retrieval query: {query[:100]}...")
if conversation_context:
logger.info(f"Generation will use conversation context ({len(conversation_context)} chars)")
try:
# Build initial state
initial_state = {
"query": query, # Use ONLY latest query for retrieval
"context": "",
"ingestor_context": "",
"result": "",
"sources": [],
"reports_filter": reports_filter or "",
"sources_filter": sources_filter or "",
"subtype_filter": subtype_filter or "",
"year_filter": year_filter or "",
"file_content": file_content,
"filename": filename,
"file_type": "unknown",
"workflow_type": "standard",
"conversation_context": conversation_context, # Store for generation
"metadata": {
"session_id": session_id,
"start_time": start_time.isoformat(),
"has_file_attachment": file_content is not None,
"has_conversation_context": conversation_context is not None
}
}
# Execute workflow nodes
if file_content and filename:
state = merge_state(initial_state, detect_file_type_node(initial_state))
state = merge_state(state, ingest_node(state))
workflow_type = route_workflow(state)
if workflow_type == "direct_output":
final_state = direct_output_node(state)
if output_format == "structured":
yield {"type": "data", "content": final_state["result"]}
yield {"type": "end", "content": ""}
else:
yield final_state["result"]
return
else:
# Retrieve using ONLY the latest query
state = merge_state(state, retrieve_node(state))
else:
# No file: retrieve using latest query only
state = merge_state(initial_state, retrieve_node(initial_state))
# Generate response with streaming
# The generator can optionally use conversation_context for better responses
sources_collected = None
        accumulated_response = None if output_format == "structured" else ""
async for partial_state in generate_node_streaming(state):
if "result" in partial_state:
if output_format == "structured":
yield {"type": "data", "content": partial_state["result"]}
else:
accumulated_response += partial_state["result"]
yield accumulated_response
if "sources" in partial_state:
sources_collected = partial_state["sources"]
# Format and yield sources
if sources_collected:
if output_format == "structured":
yield {"type": "sources", "content": sources_collected}
else:
sources_text = "\n\n**Sources:**\n"
for i, source in enumerate(sources_collected, 1):
if isinstance(source, dict):
title = source.get('title', 'Unknown')
link = source.get('link', '#')
sources_text += f"{i}. [{title}]({link})\n"
else:
sources_text += f"{i}. {source}\n"
accumulated_response += sources_text
yield accumulated_response
if output_format == "structured":
yield {"type": "end", "content": ""}
except Exception as e:
logger.error(f"Streaming pipeline failed: {str(e)}")
if output_format == "structured":
yield {"type": "error", "content": f"Error: {str(e)}"}
else:
yield f"Error: {str(e)}"