import gradio as gr
import asyncio
import logging
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from utils.generator import generate_streaming, generate
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('app.log')
    ]
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# FastAPI app for ChatUI endpoints
# ---------------------------------------------------------------------
app = FastAPI(title="ChatFed Generator", version="1.0.0")
@app.post("/generate")
async def generate_endpoint(request: Request):
"""
Non-streaming generation endpoint for ChatUI format.
Expected request body:
{
"query": "user question",
"context": [...] // list of retrieval results
}
Returns ChatUI format:
{
"answer": "response with citations [1][2]",
"sources": [{"link": "doc://...", "title": "..."}]
}
"""
try:
body = await request.json()
query = body.get("query", "")
context = body.get("context", [])
result = await generate(query, context, chatui_format=True)
return result
except Exception as e:
logger.exception("Generation endpoint failed")
return {"error": str(e)}
@app.post("/generate/stream")
async def generate_stream_endpoint(request: Request):
"""
Streaming generation endpoint for ChatUI format.
Expected request body:
{
"query": "user question",
"context": [...] // list of retrieval results
}
Returns Server-Sent Events in ChatUI format:
event: data
data: "response chunk"
event: sources
data: {"sources": [...]}
event: end
"""
try:
body = await request.json()
query = body.get("query", "")
context = body.get("context", [])
async def event_stream():
async for event in generate_streaming(query, context, chatui_format=True):
event_type = event["event"]
event_data = event["data"]
if event_type == "data":
yield f"event: data\ndata: {json.dumps(event_data)}\n\n"
elif event_type == "sources":
yield f"event: sources\ndata: {json.dumps(event_data)}\n\n"
elif event_type == "end":
yield f"event: end\ndata: {{}}\n\n"
elif event_type == "error":
yield f"event: error\ndata: {json.dumps(event_data)}\n\n"
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
}
)
    except Exception as e:
        logger.exception("Streaming endpoint failed")
        # Capture the message now: `e` is unbound once the except block
        # exits, but the generator body runs later inside StreamingResponse,
        # so referencing `e` there would raise a NameError.
        error_message = str(e)

        async def error_stream():
            yield f"event: error\ndata: {json.dumps({'error': error_message})}\n\n"

        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream"
        )
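# Example (illustrative sketch): consuming the SSE stream with `httpx`,
# assuming the server is running locally on port 7860. Each frame is an
# "event: ..." line followed by a "data: ..." line, as documented above.
#
#   import httpx
#   payload = {"query": "user question", "context": []}
#   with httpx.stream("POST", "http://localhost:7860/generate/stream",
#                     json=payload, timeout=None) as resp:
#       for line in resp.iter_lines():
#           if line:
#               print(line)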
# ---------------------------------------------------------------------
# Wrapper function to handle async streaming for Gradio
# ---------------------------------------------------------------------
def generate_streaming_wrapper(query: str, context: str):
    """Wrapper to convert an async generator to a sync generator for Gradio"""
    logger.info(f"Starting generation request - Query length: {len(query)}, Context length: {len(context)}")

    async def _async_generator():
        async for chunk in generate_streaming(query, context, chatui_format=False):
            yield chunk

    # Drive the async generator on a dedicated event loop: asyncio.get_event_loop()
    # is deprecated when no loop is set and raises in non-main threads (where
    # Gradio runs this function), and a per-call loop can be closed cleanly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    logger.debug("Created new event loop")

    async_gen = _async_generator()

    # Accumulate chunks so Gradio re-renders the full text on each update
    accumulated_text = ""
    chunk_count = 0
    try:
        while True:
            try:
                chunk = loop.run_until_complete(async_gen.__anext__())
                accumulated_text += chunk
                chunk_count += 1
                yield accumulated_text  # Yield the accumulated text, not just the chunk
            except StopAsyncIteration:
                logger.info(f"Generation completed - Total chunks: {chunk_count}, Final text length: {len(accumulated_text)}")
                break
    finally:
        # Close the generator and the loop so resources are not leaked per request
        loop.run_until_complete(async_gen.aclose())
        loop.close()
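# Example (illustrative sketch): the wrapper can also be driven outside
# Gradio, e.g. as a quick smoke test; each yield is the full text so far.
#
#   for partial in generate_streaming_wrapper("What is ChatFed?", "some context"):
#       print(partial)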
# ---------------------------------------------------------------------
# Gradio Interface with MCP support and streaming
# ---------------------------------------------------------------------
logger.info("Initializing Gradio interface")
ui = gr.Interface(
    fn=generate_streaming_wrapper,  # Streaming wrapper defined above
    inputs=[
        gr.Textbox(
            label="Query",
            lines=2,
            placeholder="Enter query here",
            info="The query to search for in the vector database"
        ),
        gr.Textbox(
            label="Context",
            lines=8,
            placeholder="Paste relevant context here",
            info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI accepts any text."
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True
    ),
    title="ChatFed Generation Module",
    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by the semantic retriever service).",
    api_name="generate"
)
# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, ui, path="/gradio")
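# Example (illustrative sketch): the mounted app also exposes a programmatic
# API; with the `gradio_client` package a call would look roughly like:
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/gradio")
#   answer = client.predict("user question", "some context", api_name="/generate")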
# Run the combined FastAPI + Gradio app with uvicorn
if __name__ == "__main__":
    import uvicorn
    logger.info("Starting ChatFed Generation Module server")
    logger.info("FastAPI server will be available at http://0.0.0.0:7860")
    logger.info("Gradio UI will be available at http://0.0.0.0:7860/gradio")
    logger.info("ChatUI endpoints: /generate (non-streaming), /generate/stream (streaming)")
    uvicorn.run(app, host="0.0.0.0", port=7860)