import os
import logging
import gradio as gr
from fastapi import FastAPI
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from langchain_core.runnables import RunnableLambda
import uvicorn
import asyncio
import base64
from typing import Optional, List
from utils import getconfig, build_conversation_context
from nodes import (
detect_file_type_node, ingest_node, direct_output_node,
retrieve_node, generate_node_streaming, route_workflow,
process_query_streaming
)
from models import GraphState, ChatUIInput, ChatUIFileInput
config = getconfig("params.cfg")
MAX_TURNS = config.getint("conversation_history", "MAX_TURNS")
MAX_CHARS = config.getint("conversation_history", "MAX_CHARS")
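# MAX_TURNS / MAX_CHARS are read from a [conversation_history] section in
# params.cfg. A minimal sketch of that section (key names match the reads
# above; the values are illustrative, not taken from the repo):
#
#   [conversation_history]
#   MAX_TURNS = 5
#   MAX_CHARS = 4000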
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
#----------------------------------------
# LANGGRAPH WORKFLOW SETUP
#----------------------------------------
# Node logic decides at runtime between direct output and the full RAG path
workflow = StateGraph(GraphState)
# Add all nodes
workflow.add_node("detect_file_type", detect_file_type_node)
workflow.add_node("ingest", ingest_node)
workflow.add_node("direct_output", direct_output_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node_streaming)
# Entry path is linear; conditional routing happens after ingestion
workflow.add_edge(START, "detect_file_type")
workflow.add_edge("detect_file_type", "ingest")
# Route after ingestion based on direct output mode
# Direct output mode routes to direct_output (return ingestor results)
# Standard mode routes to retrieve + generate (full RAG pipeline)
workflow.add_conditional_edges(
"ingest",
route_workflow,
{"direct_output": "direct_output", "standard": "retrieve"}
)
# Standard RAG path
workflow.add_edge("retrieve", "generate")
# Terminal nodes
workflow.add_edge("generate", END)
workflow.add_edge("direct_output", END)
compiled_graph = workflow.compile()
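# A minimal sketch of invoking the compiled graph directly, bypassing the
# adapters below (assumes GraphState accepts a `query` field; see models.py
# for the actual schema):
#
#   async def _demo_graph_invoke():
#       state = GraphState(query="What does the 2024 annual report say?")
#       return await compiled_graph.ainvoke(state)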
#----------------------------------------
# CHATUI ADAPTERS
#----------------------------------------
async def chatui_adapter(data):
"""Text-only adapter for ChatUI with structured message support"""
try:
# Handle both dict and object access patterns
if isinstance(data, dict):
text_value = data.get('text', '')
messages_value = data.get('messages', None)
preprompt_value = data.get('preprompt', None)
else:
text_value = getattr(data, 'text', '')
messages_value = getattr(data, 'messages', None)
preprompt_value = getattr(data, 'preprompt', None)
        # Normalize messages: convert dict entries to lightweight objects
        # exposing .role and .content (guard against a missing messages list)
        messages = []
        for msg in (messages_value or []):
if isinstance(msg, dict):
messages.append(type('Message', (), {
'role': msg.get('role', 'unknown'),
'content': msg.get('content', '')
})())
else:
messages.append(msg)
logger.info(f"Context: {messages}")
# Extract latest user query
user_messages = [msg for msg in messages if msg.role == 'user']
query = user_messages[-1].content if user_messages else text_value
        # Log the query and message counts
logger.info(f"Processing query: {query}")
logger.info(f"Total messages in conversation: {len(messages)}")
logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")
# Build conversation context for generation (last N turns)
conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)
logger.info(f"Context: {conversation_context}")
full_response = ""
sources_collected = None
async for result in process_query_streaming(
query=query,
file_upload=None,
reports_filter="",
sources_filter="",
subtype_filter="",
year_filter="",
conversation_context=conversation_context
):
if isinstance(result, dict):
result_type = result.get("type", "data")
content = result.get("content", "")
if result_type == "data":
full_response += content
yield content
elif result_type == "sources":
sources_collected = content
elif result_type == "end":
                    if sources_collected:
                        sources_text = "\n\n**Sources:**\n"
                        for i, source in enumerate(sources_collected, 1):
                            if isinstance(source, dict):
                                title = source.get('title', 'Unknown')
                                link = source.get('link', '#')
                                sources_text += f"{i}. [{title}]({link})\n"
                            else:
                                sources_text += f"{i}. {source}\n"
                        yield sources_text
elif result_type == "error":
yield f"Error: {content}"
else:
yield str(result)
            await asyncio.sleep(0)  # yield control to the event loop between chunks
except Exception as e:
logger.error(f"ChatUI error: {str(e)}")
logger.error(f"Full traceback:", exc_info=True)
yield f"Error: {str(e)}"
async def chatui_file_adapter(data):
"""File upload adapter for ChatUI with structured message support"""
try:
# Handle both dict and object access patterns
if isinstance(data, dict):
text_value = data.get('text', '')
messages_value = data.get('messages', None)
files_value = data.get('files', None)
preprompt_value = data.get('preprompt', None)
else:
text_value = getattr(data, 'text', '')
messages_value = getattr(data, 'messages', None)
files_value = getattr(data, 'files', None)
preprompt_value = getattr(data, 'preprompt', None)
        # Extract query - prefer structured messages over the raw text field
        if messages_value:
logger.info("✓ Using structured messages")
# Convert dict messages to objects
messages = []
for msg in messages_value:
if isinstance(msg, dict):
messages.append(type('Message', (), {
'role': msg.get('role', 'unknown'),
'content': msg.get('content', '')
})())
else:
messages.append(msg)
user_messages = [msg for msg in messages if msg.role == 'user']
query = user_messages[-1].content if user_messages else text_value
logger.info(f"Processing query: {query}")
logger.info(f"Total messages in conversation: {len(messages)}")
logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")
            conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)
        else:
            # Fall back to the raw text field so query/context are always defined
            query = text_value
            conversation_context = ""
file_content = None
filename = None
if files_value and len(files_value) > 0:
file_info = files_value[0]
logger.info(f"Processing file: {file_info.get('name', 'unknown')}")
if file_info.get('type') == 'base64' and file_info.get('content'):
try:
file_content = base64.b64decode(file_info['content'])
filename = file_info.get('name', 'uploaded_file')
except Exception as e:
logger.error(f"Error decoding base64 file: {str(e)}")
yield f"Error: Failed to decode uploaded file - {str(e)}"
return
sources_collected = None
async for result in process_query_streaming(
query=query,
file_content=file_content,
filename=filename,
reports_filter="",
sources_filter="",
subtype_filter="",
year_filter="",
output_format="structured",
conversation_context=conversation_context
):
if isinstance(result, dict):
result_type = result.get("type", "data")
content = result.get("content", "")
if result_type == "data":
yield content
elif result_type == "sources":
sources_collected = content
elif result_type == "end":
if sources_collected:
sources_text = "\n\n**Sources:**\n"
for i, source in enumerate(sources_collected, 1):
if isinstance(source, dict):
title = source.get('title', 'Unknown')
link = source.get('link', '#')
sources_text += f"{i}. [{title}]({link})\n"
else:
sources_text += f"{i}. {source}\n"
yield sources_text
elif result_type == "error":
yield f"Error: {content}"
else:
yield str(result)
await asyncio.sleep(0)
except Exception as e:
logger.error(f"ChatUI file adapter error: {str(e)}")
yield f"Error: {str(e)}"
#----------------------------------------
# FASTAPI SETUP
#----------------------------------------
app = FastAPI(title="ChatFed Orchestrator", version="1.0.0")
@app.get("/health")
async def health_check():
return {"status": "healthy"}
@app.get("/")
async def root():
return {
"message": "ChatFed Orchestrator API",
"endpoints": {
"health": "/health",
"chatfed-ui-stream": "/chatfed-ui-stream (LangServe)",
"chatfed-with-file-stream": "/chatfed-with-file-stream (LangServe)",
"gradio": "/gradio"
}
}
#----------------------------------------
# LANGSERVE ROUTES - endpoints for ChatUI
#----------------------------------------
# Text-only endpoint
add_routes(
app,
RunnableLambda(chatui_adapter),
path="/chatfed-ui-stream",
input_type=ChatUIInput,
output_type=str,
enable_feedback_endpoint=True,
enable_public_trace_link_endpoint=True,
)
# File upload endpoint
add_routes(
app,
RunnableLambda(chatui_file_adapter),
path="/chatfed-with-file-stream",
input_type=ChatUIFileInput,
output_type=str,
enable_feedback_endpoint=True,
enable_public_trace_link_endpoint=True,
)
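# LangServe exposes /invoke, /batch and /stream under each path registered
# above. A hedged sketch of streaming from the text-only endpoint (the exact
# payload shape is defined by ChatUIInput in models.py):
#
#   import requests
#   with requests.post(
#       "http://localhost:7860/chatfed-ui-stream/stream",
#       json={"input": {"text": "Summarize the 2024 reports", "messages": []}},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           print(line.decode())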
#----------------------------------------
# GRADIO INTERFACE - for local testing
# ACCESS: https://[ORG_NAME]-[SPACE_NAME].hf.space/gradio/
#----------------------------------------
def create_gradio_interface():
with gr.Blocks(title="ChatFed Orchestrator") as demo:
gr.Markdown("# ChatFed Orchestrator")
gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context.")
with gr.Row():
with gr.Column():
query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
file_input = gr.File(
label="Upload Document (PDF/DOCX/GeoJSON)",
file_types=[".pdf", ".docx", ".geojson", ".json"]
)
with gr.Accordion("Filters (Optional)", open=False):
reports_filter = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
sources_filter = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
subtype_filter = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
year_filter = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
submit_btn = gr.Button("Submit", variant="primary")
with gr.Column():
output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
        async def gradio_handler(query, file, reports, sources, subtype, year):
            """Stream orchestrator results into the Gradio output box"""
            async for chunk in process_query_streaming(
                query=query,
                file_upload=file,
                reports_filter=reports,
                sources_filter=sources,
                subtype_filter=subtype,
                year_filter=year,
                output_format="gradio"
            ):
                yield chunk
submit_btn.click(
fn=gradio_handler,
inputs=[query_input, file_input, reports_filter, sources_filter, subtype_filter, year_filter],
outputs=output,
)
return demo
#----------------------------------------
if __name__ == "__main__":
demo = create_gradio_interface()
app = gr.mount_gradio_app(app, demo, path="/gradio")
host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", "7860"))
logger.info(f"Starting ChatFed Orchestrator on {host}:{port}")
logger.info(f"Gradio UI: http://{host}:{port}/gradio")
logger.info(f"API Docs: http://{host}:{port}/docs")
uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)