import tempfile
import os
import logging
import json

import dotenv
import httpx
from datetime import datetime
from typing import AsyncGenerator, Optional

from gradio_client import Client, file

from models import GraphState
from utils import detect_file_type, convert_context_to_list, merge_state, getconfig
from retriever_adapter import RetrieverAdapter

dotenv.load_dotenv()

logger = logging.getLogger(__name__)
# Load config
config = getconfig("params.cfg")
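# An illustrative params.cfg layout matching the reads below. Section and
# option names come from this file; the values are placeholders, not the
# real deployment config:
#
#     [retriever]
#     RETRIEVER = org/retriever-space
#
#     [generator]
#     GENERATOR = https://org-generator.hf.space
#
#     [ingestor]
#     INGESTOR = org/ingestor-space
#
#     [general]
#     MAX_CONTEXT_CHARS = 15000
#
#     [file_processing]
#     DIRECT_OUTPUT = false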
RETRIEVER = config.get("retriever", "RETRIEVER")
GENERATOR = config.get("generator", "GENERATOR")
INGESTOR = config.get("ingestor", "INGESTOR")
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))

# Check if direct output mode is enabled
DIRECT_OUTPUT_ENABLED = config.getboolean("file_processing", "DIRECT_OUTPUT", fallback=False)

retriever_adapter = RetrieverAdapter("params.cfg")

#----------------------------------------
# LANGGRAPH NODE FUNCTIONS
#----------------------------------------
def detect_file_type_node(state: GraphState) -> GraphState:
    """Detect file type and determine workflow"""
    file_type = "unknown"
    workflow_type = "standard"

    if state.get("file_content") and state.get("filename"):
        file_type = detect_file_type(state["filename"], state["file_content"])

        # Check if direct output mode is enabled
        if DIRECT_OUTPUT_ENABLED:
            logger.info("Direct output mode enabled - file will show ingestor results directly")
            workflow_type = "direct_output"
        else:
            # Direct output disabled - use standard workflow
            logger.info("Direct output mode disabled - using standard RAG pipeline")
            workflow_type = "standard"

    metadata = state.get("metadata", {})
    metadata.update({
        "file_type": file_type,
        "workflow_type": workflow_type,
        "direct_output_enabled": DIRECT_OUTPUT_ENABLED
    })

    return {
        "file_type": file_type,
        "workflow_type": workflow_type,
        "metadata": metadata
    }
def ingest_node(state: GraphState) -> GraphState:
    """Process file through appropriate ingestor based on file type"""
    start_time = datetime.now()

    if not state.get("file_content") or not state.get("filename"):
        logger.info("No file provided, skipping ingestion")
        return {"ingestor_context": "", "metadata": state.get("metadata", {})}

    file_type = state.get("file_type", "unknown")
    logger.info(f"Ingesting {file_type} file: {state['filename']}")

    try:
        ingestor_url = INGESTOR
        logger.info(f"Using ingestor: {ingestor_url}")
        client = Client(ingestor_url, hf_token=os.getenv("HF_TOKEN"))

        # Create temporary file for upload (delete=False so the file survives
        # the handle being closed and can be read by the ingestor client)
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
            tmp_file.write(state["file_content"])
            tmp_file_path = tmp_file.name

        try:
            ingestor_context = client.predict(file(tmp_file_path), api_name="/ingest")
            logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
            if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
                raise Exception(ingestor_context)
        finally:
            os.unlink(tmp_file_path)

        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "ingestion_duration": duration,
            "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
            "ingestion_success": True,
            "ingestor_used": ingestor_url
        })
        return {"ingestor_context": ingestor_context, "metadata": metadata}

    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Ingestion failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "ingestion_duration": duration,
            "ingestion_success": False,
            "ingestion_error": str(e)
        })
        return {"ingestor_context": "", "metadata": metadata}
def direct_output_node(state: GraphState) -> GraphState:
    """
    For files when direct output mode is enabled, return ingestor results directly.
    """
    file_type = state.get("file_type", "unknown")
    logger.info(f"Direct output mode - returning ingestor results for {file_type} file")

    ingestor_context = state.get("ingestor_context", "")
    result = ingestor_context if ingestor_context else "No results from file processing."

    metadata = state.get("metadata", {})
    metadata.update({
        "processing_type": "direct_output",
        "result_length": len(result)
    })
    return {"result": result, "metadata": metadata}

def retrieve_node(state: GraphState) -> GraphState:
    """Retrieve relevant context using adapter"""
    start_time = datetime.now()
    logger.info(f"Retrieval: {state['query'][:50]}...")

    try:
        # Get filters from state (provided by ChatUI or LLM agent)
        filters = state.get("metadata_filters")

        context = retriever_adapter.retrieve(
            query=state["query"],
            filters=filters,
            hf_token=os.getenv("HF_TOKEN")
        )

        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "context_length": len(context) if context else 0,
            "retrieval_success": True,
            "filters_applied": filters,
            "retriever_config": retriever_adapter.get_metadata()
        })
        return {"context": context, "metadata": metadata}

    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        return {"context": "", "metadata": metadata}
async def generate_node_streaming(state: GraphState) -> AsyncGenerator[GraphState, None]:
    """Streaming generation using generator's FastAPI endpoint"""
    start_time = datetime.now()
    logger.info(f"Generation (streaming): {state['query'][:50]}...")

    try:
        # Combine contexts
        retrieved_context = state.get("context", "")
        ingestor_context = state.get("ingestor_context", "")
        logger.info(f"Context lengths - Ingestor: {len(ingestor_context)}, Retrieved: {len(retrieved_context)}")

        # Build context list with truncation
        context_list = []
        total_context_chars = 0

        if ingestor_context:
            truncated_ingestor = (
                ingestor_context[:MAX_CONTEXT_CHARS] + "...\n[Content truncated due to length]"
                if len(ingestor_context) > MAX_CONTEXT_CHARS
                else ingestor_context
            )
            context_list.append({
                "answer": truncated_ingestor,
                "answer_metadata": {
                    "filename": state.get("filename", "Uploaded Document"),
                    "page": "Unknown",
                    "year": "Unknown",
                    "source": "Ingestor"
                }
            })
            total_context_chars += len(truncated_ingestor)

        if retrieved_context and total_context_chars < MAX_CONTEXT_CHARS:
            retrieved_list = convert_context_to_list(retrieved_context)
            remaining_chars = MAX_CONTEXT_CHARS - total_context_chars
            for item in retrieved_list:
                item_text = item.get("answer", "")
                if len(item_text) <= remaining_chars:
                    context_list.append(item)
                    remaining_chars -= len(item_text)
                else:
                    # Partially fit the last item only if meaningful room remains
                    if remaining_chars > 100:
                        item["answer"] = item_text[:remaining_chars - 50] + "...\n[Content truncated]"
                        context_list.append(item)
                    break

        final_context_size = sum(len(item.get("answer", "")) for item in context_list)
        logger.info(f"Final context size: {final_context_size} characters (limit: {MAX_CONTEXT_CHARS})")

        payload = {"query": state["query"], "context": context_list}
        generator_url = GENERATOR

        # Stream from generator with authentication
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('HF_TOKEN')}"
        }

        # verify=False skips TLS verification; acceptable for trusted internal
        # endpoints, risky anywhere else
        async with httpx.AsyncClient(timeout=300.0, verify=False) as client:
            async with client.stream(
                "POST",
                f"{generator_url}/generate/stream",
                json=payload,
                headers=headers
            ) as response:
                if response.status_code != 200:
                    raise Exception(f"Generator returned status {response.status_code}")

                current_text = ""
                sources = None
                event_type = None
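                # The generator is assumed to emit Server-Sent Events; an
                # illustrative exchange (not captured from the real endpoint)
                # would look like:
                #
                #     event: data
                #     data: "partial answer text"
                #
                #     event: sources
                #     data: {"sources": [{"title": "...", "link": "..."}]}
                #
                #     event: end
                #     data: {}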
                async for line in response.aiter_lines():
                    if not line.strip():
                        continue

                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                        continue
                    elif line.startswith("data: "):
                        data_content = line[6:].strip()

                        if event_type == "data":
                            try:
                                chunk = json.loads(data_content)
                                if isinstance(chunk, str):
                                    current_text += chunk
                            except json.JSONDecodeError:
                                # Not JSON - treat the raw payload as text
                                current_text += data_content
                                chunk = data_content

                            metadata = state.get("metadata", {})
                            metadata.update({
                                "generation_duration": (datetime.now() - start_time).total_seconds(),
                                "result_length": len(current_text),
                                "generation_success": True,
                                "streaming": True,
                                "context_chars_used": final_context_size
                            })
                            yield {"result": chunk, "metadata": metadata}

                        elif event_type == "sources":
                            try:
                                sources_data = json.loads(data_content)
                                sources = sources_data.get("sources", [])
                                metadata = state.get("metadata", {})
                                metadata.update({
                                    "sources_received": True,
                                    "sources_count": len(sources)
                                })
                                yield {"sources": sources, "metadata": metadata}
                            except json.JSONDecodeError:
                                logger.warning(f"Failed to parse sources: {data_content}")

                        elif event_type == "end":
                            logger.info("Generator stream ended")
                            break

                        elif event_type == "error":
                            try:
                                error_data = json.loads(data_content)
                                raise Exception(error_data.get("error", "Unknown error"))
                            except json.JSONDecodeError:
                                raise Exception(data_content)
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Streaming generation failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e),
            "streaming": True
        })
        yield {"result": f"Error: {str(e)}", "metadata": metadata}

def route_workflow(state: GraphState) -> str:
    """
    Conditional routing based on workflow type after ingestion.
    Returns 'direct_output' when DIRECT_OUTPUT=True, 'standard' otherwise.
    """
    workflow_type = state.get("workflow_type", "standard")
    logger.info(f"Routing to: {workflow_type}")
    return workflow_type
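# The node functions above follow LangGraph conventions, although this module
# drives them directly in process_query_streaming() below. A minimal sketch of
# how they could be wired with langgraph's StateGraph API (an assumed
# dependency, not imported by this file):
#
#     from langgraph.graph import StateGraph, END
#
#     graph = StateGraph(GraphState)
#     graph.add_node("detect_file_type", detect_file_type_node)
#     graph.add_node("ingest", ingest_node)
#     graph.add_node("direct_output", direct_output_node)
#     graph.add_node("retrieve", retrieve_node)
#     graph.set_entry_point("detect_file_type")
#     graph.add_edge("detect_file_type", "ingest")
#     graph.add_conditional_edges(
#         "ingest",
#         route_workflow,
#         {"direct_output": "direct_output", "standard": "retrieve"},
#     )
#     graph.add_edge("direct_output", END)
#     graph.add_edge("retrieve", END)
#     app = graph.compile()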
#----------------------------------------
# UNIFIED STREAMING PROCESSOR
#----------------------------------------
async def process_query_streaming(
    query: str,
    file_upload=None,
    file_content: Optional[bytes] = None,
    filename: Optional[str] = None,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    output_format: str = "structured",
    conversation_context: Optional[str] = None  # NEW: conversation context
):
    """
    Unified streaming function with conversation context support.

    Args:
        query: Latest user query
        conversation_context: Optional conversation history for generation context
        ... (other args remain the same)
    """
    # Handle file_upload if provided
    if file_upload is not None:
        try:
            with open(file_upload.name, "rb") as f:
                file_content = f.read()
            filename = os.path.basename(file_upload.name)
            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
        except Exception as e:
            logger.error(f"Error reading uploaded file: {str(e)}")
            if output_format == "structured":
                yield {"type": "error", "content": f"Error reading file: {str(e)}"}
            else:
                yield f"Error reading file: {str(e)}"
            return

    start_time = datetime.now()
    session_id = f"stream_{start_time.strftime('%Y%m%d_%H%M%S')}"

    # Log retrieval strategy
    logger.info(f"Retrieval query: {query[:100]}...")
    if conversation_context:
        logger.info(f"Generation will use conversation context ({len(conversation_context)} chars)")

    try:
        # Build initial state
        initial_state = {
            "query": query,  # Use ONLY latest query for retrieval
            "context": "",
            "ingestor_context": "",
            "result": "",
            "sources": [],
            "reports_filter": reports_filter or "",
            "sources_filter": sources_filter or "",
            "subtype_filter": subtype_filter or "",
            "year_filter": year_filter or "",
            "file_content": file_content,
            "filename": filename,
            "file_type": "unknown",
            "workflow_type": "standard",
            "conversation_context": conversation_context,  # Store for generation
            "metadata": {
                "session_id": session_id,
                "start_time": start_time.isoformat(),
                "has_file_attachment": file_content is not None,
                "has_conversation_context": conversation_context is not None
            }
        }

        # Execute workflow nodes
        if file_content and filename:
            state = merge_state(initial_state, detect_file_type_node(initial_state))
            state = merge_state(state, ingest_node(state))

            workflow_type = route_workflow(state)
            if workflow_type == "direct_output":
                final_state = direct_output_node(state)
                if output_format == "structured":
                    yield {"type": "data", "content": final_state["result"]}
                    yield {"type": "end", "content": ""}
                else:
                    yield final_state["result"]
                return
            else:
                # Retrieve using ONLY the latest query
                state = merge_state(state, retrieve_node(state))
        else:
            # No file: retrieve using latest query only
            state = merge_state(initial_state, retrieve_node(initial_state))
        # Generate response with streaming
        # The generator can optionally use conversation_context for better responses
        sources_collected = None
        # Accumulates the full answer for plain-text outputs (e.g. gradio);
        # unused when output_format == "structured"
        accumulated_response = ""
        async for partial_state in generate_node_streaming(state):
            if "result" in partial_state:
                if output_format == "structured":
                    yield {"type": "data", "content": partial_state["result"]}
                else:
                    accumulated_response += partial_state["result"]
                    yield accumulated_response
            if "sources" in partial_state:
                sources_collected = partial_state["sources"]

        # Format and yield sources
        if sources_collected:
            if output_format == "structured":
                yield {"type": "sources", "content": sources_collected}
            else:
                sources_text = "\n\n**Sources:**\n"
                for i, source in enumerate(sources_collected, 1):
                    if isinstance(source, dict):
                        title = source.get("title", "Unknown")
                        link = source.get("link", "#")
                        sources_text += f"{i}. [{title}]({link})\n"
                    else:
                        sources_text += f"{i}. {source}\n"
                accumulated_response += sources_text
                yield accumulated_response

        if output_format == "structured":
            yield {"type": "end", "content": ""}

    except Exception as e:
        logger.error(f"Streaming pipeline failed: {str(e)}")
        if output_format == "structured":
            yield {"type": "error", "content": f"Error: {str(e)}"}
        else:
            yield f"Error: {str(e)}"