import os
import logging
import gradio as gr
from fastapi import FastAPI
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from langchain_core.runnables import RunnableLambda
import uvicorn
import asyncio
import base64
from typing import Optional, List
from utils import getconfig, build_conversation_context
from nodes import (
    detect_file_type_node,
    ingest_node,
    direct_output_node,
    retrieve_node,
    generate_node_streaming,
    route_workflow,
    process_query_streaming
)
from models import GraphState, ChatUIInput, ChatUIFileInput

config = getconfig("params.cfg")
MAX_TURNS = config.getint("conversation_history", "MAX_TURNS")
MAX_CHARS = config.getint("conversation_history", "MAX_CHARS")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

#----------------------------------------
# LANGGRAPH WORKFLOW SETUP
#----------------------------------------
# Workflow handles direct output caching automatically in nodes
workflow = StateGraph(GraphState)

# Add all nodes
workflow.add_node("detect_file_type", detect_file_type_node)
workflow.add_node("ingest", ingest_node)
workflow.add_node("direct_output", direct_output_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node_streaming)

# Simple linear path - node logic handles routing
workflow.add_edge(START, "detect_file_type")
workflow.add_edge("detect_file_type", "ingest")

# Route after ingestion based on direct output mode:
# - direct output mode routes to direct_output (returns ingestor results as-is)
# - standard mode routes to retrieve + generate (full RAG pipeline)
workflow.add_conditional_edges(
    "ingest",
    route_workflow,
    {"direct_output": "direct_output", "standard": "retrieve"}
)

# Standard RAG path
workflow.add_edge("retrieve", "generate")

# Terminal nodes
workflow.add_edge("generate", END)
workflow.add_edge("direct_output", END)

compiled_graph = workflow.compile()
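
# Resulting graph topology (derived from the edges registered above):
#
#   START -> detect_file_type -> ingest --("direct_output")--> direct_output -> END
#                                       \--("standard")------> retrieve -> generate -> END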

#----------------------------------------
# CHATUI ADAPTERS
#----------------------------------------

async def chatui_adapter(data):
    """Text-only adapter for ChatUI with structured message support"""
    try:
        # Handle both dict and object access patterns
        if isinstance(data, dict):
            text_value = data.get('text', '')
            messages_value = data.get('messages', None)
            preprompt_value = data.get('preprompt', None)  # currently unused
        else:
            text_value = getattr(data, 'text', '')
            messages_value = getattr(data, 'messages', None)
            preprompt_value = getattr(data, 'preprompt', None)  # currently unused

        # Convert dict messages to objects if needed (guard against missing messages)
        messages = []
        for msg in (messages_value or []):
            if isinstance(msg, dict):
                messages.append(type('Message', (), {
                    'role': msg.get('role', 'unknown'),
                    'content': msg.get('content', '')
                })())
            else:
                messages.append(msg)

        # Extract latest user query
        user_messages = [msg for msg in messages if msg.role == 'user']
        query = user_messages[-1].content if user_messages else text_value

        # Log conversation context
        logger.info(f"Processing query: {query}")
        logger.info(f"Total messages in conversation: {len(messages)}")
        logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")

        # Build conversation context for generation (last N turns)
        conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)
        logger.info(f"Context: {conversation_context}")

        full_response = ""
        sources_collected = None

        async for result in process_query_streaming(
            query=query,
            file_upload=None,
            reports_filter="",
            sources_filter="",
            subtype_filter="",
            year_filter="",
            conversation_context=conversation_context
        ):
            if isinstance(result, dict):
                result_type = result.get("type", "data")
                content = result.get("content", "")

                if result_type == "data":
                    full_response += content
                    yield content
                elif result_type == "sources":
                    sources_collected = content
                elif result_type == "end":
                    if sources_collected:
                        sources_text = "\n\n**Sources:**\n"
                        for i, source in enumerate(sources_collected, 1):
                            sources_text += f"{i}. [{source.get('title', 'Unknown')}]({source.get('link', '#')})\n"
                        yield sources_text
                elif result_type == "error":
                    yield f"Error: {content}"
            else:
                yield str(result)
            await asyncio.sleep(0)

    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        logger.error("Full traceback:", exc_info=True)
        yield f"Error: {str(e)}"


async def chatui_file_adapter(data):
    """File upload adapter for ChatUI with structured message support"""
    try:
        # Handle both dict and object access patterns
        if isinstance(data, dict):
            text_value = data.get('text', '')
            messages_value = data.get('messages', None)
            files_value = data.get('files', None)
            preprompt_value = data.get('preprompt', None)  # currently unused
        else:
            text_value = getattr(data, 'text', '')
            messages_value = getattr(data, 'messages', None)
            files_value = getattr(data, 'files', None)
            preprompt_value = getattr(data, 'preprompt', None)  # currently unused

        # Extract query - prefer structured messages, fall back to raw text
        query = text_value
        conversation_context = ""
        if messages_value and len(messages_value) > 0:
            logger.info("✓ Using structured messages")
            # Convert dict messages to objects
            messages = []
            for msg in messages_value:
                if isinstance(msg, dict):
                    messages.append(type('Message', (), {
                        'role': msg.get('role', 'unknown'),
                        'content': msg.get('content', '')
                    })())
                else:
                    messages.append(msg)

            user_messages = [msg for msg in messages if msg.role == 'user']
            query = user_messages[-1].content if user_messages else text_value

            logger.info(f"Processing query: {query}")
            logger.info(f"Total messages in conversation: {len(messages)}")
            logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")

            conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)

        file_content = None
        filename = None
        if files_value and len(files_value) > 0:
            file_info = files_value[0]
            logger.info(f"Processing file: {file_info.get('name', 'unknown')}")
            if file_info.get('type') == 'base64' and file_info.get('content'):
                try:
                    file_content = base64.b64decode(file_info['content'])
                    filename = file_info.get('name', 'uploaded_file')
                except Exception as e:
                    logger.error(f"Error decoding base64 file: {str(e)}")
                    yield f"Error: Failed to decode uploaded file - {str(e)}"
                    return

        sources_collected = None

        async for result in process_query_streaming(
            query=query,
            file_content=file_content,
            filename=filename,
            reports_filter="",
            sources_filter="",
            subtype_filter="",
            year_filter="",
            output_format="structured",
            conversation_context=conversation_context
        ):
            if isinstance(result, dict):
                result_type = result.get("type", "data")
                content = result.get("content", "")

                if result_type == "data":
                    yield content
                elif result_type == "sources":
                    sources_collected = content
                elif result_type == "end":
                    if sources_collected:
                        sources_text = "\n\n**Sources:**\n"
                        for i, source in enumerate(sources_collected, 1):
                            if isinstance(source, dict):
                                title = source.get('title', 'Unknown')
                                link = source.get('link', '#')
                                sources_text += f"{i}. [{title}]({link})\n"
                            else:
                                sources_text += f"{i}. {source}\n"
                        yield sources_text
                elif result_type == "error":
                    yield f"Error: {content}"
            else:
                yield str(result)
            await asyncio.sleep(0)

    except Exception as e:
        logger.error(f"ChatUI file adapter error: {str(e)}")
        yield f"Error: {str(e)}"
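
# Example input shape for chatui_file_adapter (inferred from the field accesses
# above; the exact ChatUI contract may differ). Note "type" must be "base64"
# for the file content to be decoded:
#
#   {
#     "text": "Summarize this report",
#     "messages": [{"role": "user", "content": "Summarize this report"}],
#     "files": [{"name": "report.pdf", "type": "base64", "content": "<base64-encoded bytes>"}]
#   }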

#----------------------------------------
# FASTAPI SETUP
#----------------------------------------
app = FastAPI(title="ChatFed Orchestrator", version="1.0.0")

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/")
async def root():
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed-ui-stream": "/chatfed-ui-stream (LangServe)",
            "chatfed-with-file-stream": "/chatfed-with-file-stream (LangServe)",
            "gradio": "/gradio"
        }
    }

#----------------------------------------
# LANGSERVE ROUTES - endpoints for ChatUI
#----------------------------------------
# Text-only endpoint
add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)

# File upload endpoint
add_routes(
    app,
    RunnableLambda(chatui_file_adapter),
    path="/chatfed-with-file-stream",
    input_type=ChatUIFileInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)

#----------------------------------------
# GRADIO INTERFACE - for local testing
# ACCESS: https://[ORG_NAME]-[SPACE_NAME].hf.space/gradio/
#----------------------------------------
def create_gradio_interface():
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context.")

        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                file_input = gr.File(
                    label="Upload Document (PDF/DOCX/GeoJSON)",
                    file_types=[".pdf", ".docx", ".geojson", ".json"]
                )
                with gr.Accordion("Filters (Optional)", open=False):
                    reports_filter = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                    sources_filter = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                    subtype_filter = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                    year_filter = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=15, show_copy_button=True)

        async def gradio_handler(query, file, reports, sources, subtype, year):
            """Handler for the Gradio interface; yields each chunk as the current response"""
            result = ""
            async for chunk in process_query_streaming(
                query=query,
                file_upload=file,
                reports_filter=reports,
                sources_filter=sources,
                subtype_filter=subtype,
                year_filter=year,
                output_format="gradio"
            ):
                result = chunk
                yield result

        submit_btn.click(
            fn=gradio_handler,
            inputs=[query_input, file_input, reports_filter, sources_filter, subtype_filter, year_filter],
            outputs=output,
        )

    return demo

#----------------------------------------
if __name__ == "__main__":
    demo = create_gradio_interface()
    app = gr.mount_gradio_app(app, demo, path="/gradio")

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))

    logger.info(f"Starting ChatFed Orchestrator on {host}:{port}")
    logger.info(f"Gradio UI: http://{host}:{port}/gradio")
    logger.info(f"API Docs: http://{host}:{port}/docs")

    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)
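
# Example client call (a sketch, not verified against this deployment):
# LangServe's add_routes exposes a streaming endpoint at POST {path}/stream
# that takes the adapter input under an "input" key and responds with
# server-sent events. The payload fields mirror ChatUIInput as read by
# chatui_adapter above:
#
#   curl -N http://localhost:7860/chatfed-ui-stream/stream \
#     -H "Content-Type: application/json" \
#     -d '{"input": {"text": "What changed in 2024?", "messages": [{"role": "user", "content": "What changed in 2024?"}]}}'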