import tempfile import os from models import GraphState from datetime import datetime from gradio_client import Client, file import logging import dotenv import httpx import json from typing import Generator, Optional from utils import detect_file_type, convert_context_to_list, merge_state, getconfig from retriever_adapter import RetrieverAdapter dotenv.load_dotenv() logger = logging.getLogger(__name__) # Load config config = getconfig("params.cfg") RETRIEVER = config.get("retriever", "RETRIEVER") GENERATOR = config.get("generator", "GENERATOR") INGESTOR = config.get("ingestor", "INGESTOR") MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS")) # Check if direct output mode is enabled DIRECT_OUTPUT_ENABLED = config.getboolean("file_processing", "DIRECT_OUTPUT", fallback=False) retriever_adapter = RetrieverAdapter("params.cfg") #---------------------------------------- # LANGGRAPH NODE FUNCTIONS #---------------------------------------- def detect_file_type_node(state: GraphState) -> GraphState: """Detect file type and determine workflow""" file_type = "unknown" workflow_type = "standard" if state.get("file_content") and state.get("filename"): file_type = detect_file_type(state["filename"], state["file_content"]) # Check if direct output mode is enabled if DIRECT_OUTPUT_ENABLED: logger.info(f"Direct output mode enabled - file will show ingestor results directly") workflow_type = "direct_output" else: # Direct output disabled - use standard workflow logger.info(f"Direct output mode disabled - using standard RAG pipeline") workflow_type = "standard" metadata = state.get("metadata", {}) metadata.update({ "file_type": file_type, "workflow_type": workflow_type, "direct_output_enabled": DIRECT_OUTPUT_ENABLED }) return { "file_type": file_type, "workflow_type": workflow_type, "metadata": metadata } def ingest_node(state: GraphState) -> GraphState: """Process file through appropriate ingestor based on file type""" start_time = datetime.now() if not state.get("file_content") or not state.get("filename"): logger.info("No file provided, skipping ingestion") return {"ingestor_context": "", "metadata": state.get("metadata", {})} file_type = state.get("file_type", "unknown") logger.info(f"Ingesting {file_type} file: {state['filename']}") try: ingestor_url = INGESTOR logger.info(f"Using ingestor: {ingestor_url}") client = Client(ingestor_url, hf_token=os.getenv("HF_TOKEN")) # Create temporary file for upload with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file: tmp_file.write(state["file_content"]) tmp_file_path = tmp_file.name try: ingestor_context = client.predict(file(tmp_file_path), api_name="/ingest") logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}") if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"): raise Exception(ingestor_context) finally: os.unlink(tmp_file_path) duration = (datetime.now() - start_time).total_seconds() metadata = state.get("metadata", {}) metadata.update({ "ingestion_duration": duration, "ingestor_context_length": len(ingestor_context) if ingestor_context else 0, "ingestion_success": True, "ingestor_used": ingestor_url }) return {"ingestor_context": ingestor_context, "metadata": metadata} except Exception as e: duration = (datetime.now() - start_time).total_seconds() logger.error(f"Ingestion failed: {str(e)}") metadata = state.get("metadata", {}) metadata.update({ "ingestion_duration": duration, "ingestion_success": False, "ingestion_error": str(e) }) return {"ingestor_context": "", "metadata": metadata} def direct_output_node(state: GraphState) -> GraphState: """ For files when direct output mode is enabled, return ingestor results directly. """ file_type = state.get('file_type', 'unknown') logger.info(f"Direct output mode - returning ingestor results for {file_type} file") ingestor_context = state.get("ingestor_context", "") result = ingestor_context if ingestor_context else "No results from file processing." metadata = state.get("metadata", {}) metadata.update({ "processing_type": "direct_output", "result_length": len(result) }) return {"result": result, "metadata": metadata} def retrieve_node(state: GraphState) -> GraphState: """Retrieve relevant context using adapter""" start_time = datetime.now() logger.info(f"Retrieval: {state['query'][:50]}...") try: # Get filters from state (provided by ChatUI or LLM agent) filters = state.get("metadata_filters") context = retriever_adapter.retrieve( query=state["query"], filters=filters, hf_token=os.getenv("HF_TOKEN") ) duration = (datetime.now() - start_time).total_seconds() metadata = state.get("metadata", {}) metadata.update({ "retrieval_duration": duration, "context_length": len(context) if context else 0, "retrieval_success": True, "filters_applied": filters, "retriever_config": retriever_adapter.get_metadata() }) return {"context": context, "metadata": metadata} except Exception as e: duration = (datetime.now() - start_time).total_seconds() logger.error(f"Retrieval failed: {str(e)}") metadata = state.get("metadata", {}) metadata.update({ "retrieval_duration": duration, "retrieval_success": False, "retrieval_error": str(e) }) return {"context": "", "metadata": metadata} async def generate_node_streaming(state: GraphState) -> Generator[GraphState, None, None]: """Streaming generation using generator's FastAPI endpoint""" start_time = datetime.now() logger.info(f"Generation (streaming): {state['query'][:50]}...") try: # Combine contexts retrieved_context = state.get("context", "") ingestor_context = state.get("ingestor_context", "") logger.info(f"Context lengths - Ingestor: {len(ingestor_context)}, Retrieved: {len(retrieved_context)}") # Build context list with truncation context_list = [] total_context_chars = 0 if ingestor_context: truncated_ingestor = ( ingestor_context[:MAX_CONTEXT_CHARS] + "...\n[Content truncated due to length]" if len(ingestor_context) > MAX_CONTEXT_CHARS else ingestor_context ) context_list.append({ "answer": truncated_ingestor, "answer_metadata": { "filename": state.get("filename", "Uploaded Document"), "page": "Unknown", "year": "Unknown", "source": "Ingestor" } }) total_context_chars += len(truncated_ingestor) if retrieved_context and total_context_chars < MAX_CONTEXT_CHARS: retrieved_list = convert_context_to_list(retrieved_context) remaining_chars = MAX_CONTEXT_CHARS - total_context_chars for item in retrieved_list: item_text = item.get("answer", "") if len(item_text) <= remaining_chars: context_list.append(item) remaining_chars -= len(item_text) else: if remaining_chars > 100: item["answer"] = item_text[:remaining_chars-50] + "...\n[Content truncated]" context_list.append(item) break final_context_size = sum(len(item.get("answer", "")) for item in context_list) logger.info(f"Final context size: {final_context_size} characters (limit: {MAX_CONTEXT_CHARS})") payload = {"query": state["query"], "context": context_list} # Normalize generator URL generator_url = GENERATOR # Stream from generator with authentication headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.getenv('HF_TOKEN')}" } async with httpx.AsyncClient(timeout=300.0, verify=False) as client: async with client.stream( "POST", f"{generator_url}/generate/stream", json=payload, headers=headers ) as response: if response.status_code != 200: raise Exception(f"Generator returned status {response.status_code}") current_text = "" sources = None event_type = None async for line in response.aiter_lines(): if not line.strip(): continue if line.startswith("event: "): event_type = line[7:].strip() continue elif line.startswith("data: "): data_content = line[6:].strip() if event_type == "data": try: chunk = json.loads(data_content) if isinstance(chunk, str): current_text += chunk except json.JSONDecodeError: current_text += data_content chunk = data_content metadata = state.get("metadata", {}) metadata.update({ "generation_duration": (datetime.now() - start_time).total_seconds(), "result_length": len(current_text), "generation_success": True, "streaming": True, "context_chars_used": final_context_size }) yield {"result": chunk, "metadata": metadata} elif event_type == "sources": try: sources_data = json.loads(data_content) sources = sources_data.get("sources", []) metadata = state.get("metadata", {}) metadata.update({ "sources_received": True, "sources_count": len(sources) }) yield {"sources": sources, "metadata": metadata} except json.JSONDecodeError: logger.warning(f"Failed to parse sources: {data_content}") elif event_type == "end": logger.info("Generator stream ended") break elif event_type == "error": try: error_data = json.loads(data_content) raise Exception(error_data.get("error", "Unknown error")) except json.JSONDecodeError: raise Exception(data_content) except Exception as e: duration = (datetime.now() - start_time).total_seconds() logger.error(f"Streaming generation failed: {str(e)}") metadata = state.get("metadata", {}) metadata.update({ "generation_duration": duration, "generation_success": False, "generation_error": str(e), "streaming": True }) yield {"result": f"Error: {str(e)}", "metadata": metadata} def route_workflow(state: GraphState) -> str: """ Conditional routing based on workflow type after ingestion. Returns 'direct_output' when DIRECT_OUTPUT=True, 'standard' otherwise. """ workflow_type = state.get("workflow_type", "standard") logger.info(f"Routing to: {workflow_type}") return workflow_type #---------------------------------------- # UNIFIED STREAMING PROCESSOR #---------------------------------------- async def process_query_streaming( query: str, file_upload=None, file_content: Optional[bytes] = None, filename: Optional[str] = None, reports_filter: str = "", sources_filter: str = "", subtype_filter: str = "", year_filter: str = "", output_format: str = "structured", conversation_context: Optional[str] = None # NEW: conversation context ): """ Unified streaming function with conversation context support. Args: query: Latest user query conversation_context: Optional conversation history for generation context ... (other args remain the same) """ # Handle file_upload if provided if file_upload is not None: try: with open(file_upload.name, 'rb') as f: file_content = f.read() filename = os.path.basename(file_upload.name) logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes") except Exception as e: logger.error(f"Error reading uploaded file: {str(e)}") if output_format == "structured": yield {"type": "error", "content": f"Error reading file: {str(e)}"} else: yield f"Error reading file: {str(e)}" return start_time = datetime.now() session_id = f"stream_{start_time.strftime('%Y%m%d_%H%M%S')}" # Log retrieval strategy logger.info(f"Retrieval query: {query[:100]}...") if conversation_context: logger.info(f"Generation will use conversation context ({len(conversation_context)} chars)") try: # Build initial state initial_state = { "query": query, # Use ONLY latest query for retrieval "context": "", "ingestor_context": "", "result": "", "sources": [], "reports_filter": reports_filter or "", "sources_filter": sources_filter or "", "subtype_filter": subtype_filter or "", "year_filter": year_filter or "", "file_content": file_content, "filename": filename, "file_type": "unknown", "workflow_type": "standard", "conversation_context": conversation_context, # Store for generation "metadata": { "session_id": session_id, "start_time": start_time.isoformat(), "has_file_attachment": file_content is not None, "has_conversation_context": conversation_context is not None } } # Execute workflow nodes if file_content and filename: state = merge_state(initial_state, detect_file_type_node(initial_state)) state = merge_state(state, ingest_node(state)) workflow_type = route_workflow(state) if workflow_type == "direct_output": final_state = direct_output_node(state) if output_format == "structured": yield {"type": "data", "content": final_state["result"]} yield {"type": "end", "content": ""} else: yield final_state["result"] return else: # Retrieve using ONLY the latest query state = merge_state(state, retrieve_node(state)) else: # No file: retrieve using latest query only state = merge_state(initial_state, retrieve_node(initial_state)) # Generate response with streaming # The generator can optionally use conversation_context for better responses sources_collected = None accumulated_response = "" if output_format == "gradio" else None async for partial_state in generate_node_streaming(state): if "result" in partial_state: if output_format == "structured": yield {"type": "data", "content": partial_state["result"]} else: accumulated_response += partial_state["result"] yield accumulated_response if "sources" in partial_state: sources_collected = partial_state["sources"] # Format and yield sources if sources_collected: if output_format == "structured": yield {"type": "sources", "content": sources_collected} else: sources_text = "\n\n**Sources:**\n" for i, source in enumerate(sources_collected, 1): if isinstance(source, dict): title = source.get('title', 'Unknown') link = source.get('link', '#') sources_text += f"{i}. [{title}]({link})\n" else: sources_text += f"{i}. {source}\n" accumulated_response += sources_text yield accumulated_response if output_format == "structured": yield {"type": "end", "content": ""} except Exception as e: logger.error(f"Streaming pipeline failed: {str(e)}") if output_format == "structured": yield {"type": "error", "content": f"Error: {str(e)}"} else: yield f"Error: {str(e)}"