Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import os | |
| import logging | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from langserve import add_routes | |
| from langgraph.graph import StateGraph, START, END | |
| from langchain_core.runnables import RunnableLambda | |
| import uvicorn | |
| import asyncio | |
| import base64 | |
| from typing import Optional, List | |
| from utils import getconfig, build_conversation_context | |
| from nodes import ( | |
| detect_file_type_node, ingest_node, direct_output_node, | |
| retrieve_node, generate_node_streaming, route_workflow, | |
| process_query_streaming | |
| ) | |
| from models import GraphState, ChatUIInput, ChatUIFileInput | |
# Load runtime configuration; the conversation-history limits bound how much
# prior chat context is replayed into each generation request.
config = getconfig("params.cfg")
MAX_TURNS = config.getint("conversation_history", "MAX_TURNS")
MAX_CHARS = config.getint("conversation_history", "MAX_CHARS")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
#----------------------------------------
# LANGGRAPH WORKFLOW SETUP
#----------------------------------------
# Builds the orchestration graph:
#   START -> detect_file_type -> ingest -> (direct_output | retrieve -> generate) -> END
# Workflow handles direct output caching automatically in nodes
workflow = StateGraph(GraphState)

# Add all nodes
workflow.add_node("detect_file_type", detect_file_type_node)
workflow.add_node("ingest", ingest_node)
workflow.add_node("direct_output", direct_output_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node_streaming)

# Simple linear path - node logic handles routing
workflow.add_edge(START, "detect_file_type")
workflow.add_edge("detect_file_type", "ingest")

# Route after ingestion based on direct output mode:
# - "direct_output" returns the ingestor results as-is
# - "standard" runs the full RAG pipeline (retrieve + generate)
workflow.add_conditional_edges(
    "ingest",
    route_workflow,
    {"direct_output": "direct_output", "standard": "retrieve"}
)

# Standard RAG path
workflow.add_edge("retrieve", "generate")

# Terminal nodes
workflow.add_edge("generate", END)
workflow.add_edge("direct_output", END)

# Compiled graph is importable by other modules; keep the name stable.
compiled_graph = workflow.compile()
| #---------------------------------------- | |
| # CHATUI ADAPTERS | |
| #---------------------------------------- | |
async def chatui_adapter(data):
    """Text-only LangServe adapter for ChatUI with structured message support.

    Streams answer chunks from ``process_query_streaming`` and, after the
    stream signals "end", yields a trailing markdown "Sources" footer.

    Args:
        data: ChatUI request payload; may arrive as a dict or as an object
            with attributes. Expected fields: ``text`` (str fallback query)
            and ``messages`` (optional list of role/content messages, each a
            dict or an object).

    Yields:
        str: response chunks, then an optional sources footer, or an
        ``"Error: ..."`` string on failure.
    """
    try:
        # Payload may be a plain dict or a pydantic-style object depending on
        # how LangServe deserialized the request.
        if isinstance(data, dict):
            text_value = data.get('text', '')
            messages_value = data.get('messages', None)
        else:
            text_value = getattr(data, 'text', '')
            messages_value = getattr(data, 'messages', None)

        # Normalize messages to objects exposing .role / .content.
        # FIX: the original iterated `messages_value` unguarded and raised
        # TypeError when `messages` was absent/None (the file adapter guards
        # this case); treat a missing list as empty instead.
        messages = []
        for msg in (messages_value or []):
            if isinstance(msg, dict):
                messages.append(type('Message', (), {
                    'role': msg.get('role', 'unknown'),
                    'content': msg.get('content', '')
                })())
            else:
                messages.append(msg)
        logger.info(f"Context: {messages}")

        # Latest user turn is the query; fall back to the raw `text` field.
        user_messages = [msg for msg in messages if msg.role == 'user']
        query = user_messages[-1].content if user_messages else text_value

        # Log conversation context
        logger.info(f"Processing query: {query}")
        logger.info(f"Total messages in conversation: {len(messages)}")
        logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")

        # Build conversation context for generation (last N turns)
        conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)
        logger.info(f"Context: {conversation_context}")

        full_response = ""
        sources_collected = None
        async for result in process_query_streaming(
            query=query,
            file_upload=None,
            reports_filter="",
            sources_filter="",
            subtype_filter="",
            year_filter="",
            conversation_context=conversation_context
        ):
            if isinstance(result, dict):
                result_type = result.get("type", "data")
                content = result.get("content", "")
                if result_type == "data":
                    full_response += content
                    yield content
                elif result_type == "sources":
                    # Sources arrive once mid-stream; hold them until "end".
                    sources_collected = content
                elif result_type == "end":
                    if sources_collected:
                        sources_text = "\n\n**Sources:**\n"
                        for i, source in enumerate(sources_collected, 1):
                            sources_text += f"{i}. [{source.get('title', 'Unknown')}]({source.get('link', '#')})\n"
                        yield sources_text
                elif result_type == "error":
                    yield f"Error: {content}"
            else:
                yield str(result)
            # Yield control so each chunk is flushed to the client promptly.
            await asyncio.sleep(0)
    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        logger.error("Full traceback:", exc_info=True)
        yield f"Error: {str(e)}"
async def chatui_file_adapter(data):
    """File-upload LangServe adapter for ChatUI with structured message support.

    Like ``chatui_adapter`` but also accepts an optional base64-encoded file
    whose decoded bytes are forwarded to the ingestion pipeline, and requests
    structured output from the workflow.

    Args:
        data: ChatUI request payload; may arrive as a dict or as an object
            with attributes. Expected fields: ``text`` (str fallback query),
            ``messages`` (optional message list) and ``files`` (optional list
            of file descriptors with ``name``/``type``/``content`` keys).

    Yields:
        str: response chunks, then an optional sources footer, or an
        ``"Error: ..."`` string on failure.
    """
    try:
        # Handle both dict and object access patterns
        if isinstance(data, dict):
            text_value = data.get('text', '')
            messages_value = data.get('messages', None)
            files_value = data.get('files', None)
        else:
            text_value = getattr(data, 'text', '')
            messages_value = getattr(data, 'messages', None)
            files_value = getattr(data, 'files', None)

        # Extract query - prefer structured messages
        if messages_value and len(messages_value) > 0:
            logger.info("✓ Using structured messages")
            # Convert dict messages to objects
            messages = []
            for msg in messages_value:
                if isinstance(msg, dict):
                    messages.append(type('Message', (), {
                        'role': msg.get('role', 'unknown'),
                        'content': msg.get('content', '')
                    })())
                else:
                    messages.append(msg)
            user_messages = [msg for msg in messages if msg.role == 'user']
            query = user_messages[-1].content if user_messages else text_value
            logger.info(f"Processing query: {query}")
            logger.info(f"Total messages in conversation: {len(messages)}")
            logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")
            conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)
        else:
            # FIX: originally `query` and `conversation_context` were bound
            # only inside the branch above, raising NameError further down
            # whenever no structured messages were supplied.
            query = text_value
            conversation_context = ""

        file_content = None
        filename = None
        if files_value and len(files_value) > 0:
            # Only the first uploaded file is processed.
            file_info = files_value[0]
            logger.info(f"Processing file: {file_info.get('name', 'unknown')}")
            if file_info.get('type') == 'base64' and file_info.get('content'):
                try:
                    file_content = base64.b64decode(file_info['content'])
                    filename = file_info.get('name', 'uploaded_file')
                except Exception as e:
                    logger.error(f"Error decoding base64 file: {str(e)}")
                    yield f"Error: Failed to decode uploaded file - {str(e)}"
                    return

        sources_collected = None
        async for result in process_query_streaming(
            query=query,
            file_content=file_content,
            filename=filename,
            reports_filter="",
            sources_filter="",
            subtype_filter="",
            year_filter="",
            output_format="structured",
            conversation_context=conversation_context
        ):
            if isinstance(result, dict):
                result_type = result.get("type", "data")
                content = result.get("content", "")
                if result_type == "data":
                    yield content
                elif result_type == "sources":
                    # Sources arrive once mid-stream; hold them until "end".
                    sources_collected = content
                elif result_type == "end":
                    if sources_collected:
                        sources_text = "\n\n**Sources:**\n"
                        for i, source in enumerate(sources_collected, 1):
                            if isinstance(source, dict):
                                title = source.get('title', 'Unknown')
                                link = source.get('link', '#')
                                sources_text += f"{i}. [{title}]({link})\n"
                            else:
                                sources_text += f"{i}. {source}\n"
                        yield sources_text
                elif result_type == "error":
                    yield f"Error: {content}"
            else:
                yield str(result)
            # Yield control so each chunk is flushed to the client promptly.
            await asyncio.sleep(0)
    except Exception as e:
        logger.error(f"ChatUI file adapter error: {str(e)}")
        yield f"Error: {str(e)}"
#----------------------------------------
# FASTAPI SETUP - for future use
#----------------------------------------
app = FastAPI(title="ChatFed Orchestrator", version="1.0.0")


# FIX: these coroutines were defined but never registered on the app, so the
# `/health` endpoint advertised by root() was unreachable. Register them.
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service is up."""
    return {"status": "healthy"}


@app.get("/")
async def root():
    """Describe the API surface (endpoint directory) for quick discovery."""
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed-ui-stream": "/chatfed-ui-stream (LangServe)",
            "chatfed-with-file-stream": "/chatfed-with-file-stream (LangServe)",
            "gradio": "/gradio"
        }
    }
#----------------------------------------
# LANGSERVE ROUTES - endpoints for ChatUI
#----------------------------------------
# LangServe exposes /invoke, /stream, /batch etc. under each path.

# Text-only endpoint
add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)

# File upload endpoint
add_routes(
    app,
    RunnableLambda(chatui_file_adapter),
    path="/chatfed-with-file-stream",
    input_type=ChatUIFileInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)
#----------------------------------------
# GRADIO INTERFACE - for local testing
# ACCESS: https://[ORG_NAME]-[SPACE_NAME].hf.space/gradio/
#----------------------------------------
def create_gradio_interface():
    """Build the Gradio Blocks test UI (query + optional file + filters).

    Returns:
        gr.Blocks: the assembled demo, later mounted at /gradio.
    """
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context.")
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                file_input = gr.File(
                    label="Upload Document (PDF/DOCX/GeoJSON)",
                    file_types=[".pdf", ".docx", ".geojson", ".json"]
                )
                with gr.Accordion("Filters (Optional)", open=False):
                    reports_filter = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                    sources_filter = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                    subtype_filter = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                    year_filter = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=15, show_copy_button=True)

        async def gradio_handler(query, file, reports, sources, subtype, year):
            """Handler for Gradio interface"""
            result = ""
            # Each yielded chunk REPLACES the textbox content rather than
            # appending — NOTE(review): this assumes output_format="gradio"
            # emits cumulative text; confirm against process_query_streaming.
            async for chunk in process_query_streaming(
                query=query,
                file_upload=file,
                reports_filter=reports,
                sources_filter=sources,
                subtype_filter=subtype,
                year_filter=year,
                output_format="gradio"
            ):
                result = chunk
                yield result

        submit_btn.click(
            fn=gradio_handler,
            inputs=[query_input, file_input, reports_filter, sources_filter, subtype_filter, year_filter],
            outputs=output,
        )
    return demo
| #---------------------------------------- | |
if __name__ == "__main__":
    # Mount the Gradio test UI onto the FastAPI app before serving it.
    gradio_ui = create_gradio_interface()
    app = gr.mount_gradio_app(app, gradio_ui, path="/gradio")

    # Bind address is configurable via environment (HF Spaces default: 7860).
    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", "7860"))

    logger.info(f"Starting ChatFed Orchestrator on {bind_host}:{bind_port}")
    logger.info(f"Gradio UI: http://{bind_host}:{bind_port}/gradio")
    logger.info(f"API Docs: http://{bind_host}:{bind_port}/docs")

    uvicorn.run(app, host=bind_host, port=bind_port, log_level="info", access_log=True)