import gradio as gr
import asyncio
import logging
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from utils.generator import generate_streaming, generate

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('app.log')
    ]
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------
# FastAPI app for ChatUI endpoints
# ---------------------------------------------------------------------
app = FastAPI(title="ChatFed Generator", version="1.0.0")

@app.post("/generate")
async def generate_endpoint(request: Request):
    """
    Non-streaming generation endpoint for ChatUI format.

    Expected request body:
    {
        "query": "user question",
        "context": [...]  // list of retrieval results
    }

    Returns ChatUI format:
    {
        "answer": "response with citations [1][2]",
        "sources": [{"link": "doc://...", "title": "..."}]
    }
    """
    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])

        result = await generate(query, context, chatui_format=True)
        return result
    except Exception as e:
        logger.exception("Generation endpoint failed")
        return {"error": str(e)}

@app.post("/generate/stream")
async def generate_stream_endpoint(request: Request):
    """
    Streaming generation endpoint for ChatUI format.

    Expected request body:
    {
        "query": "user question",
        "context": [...]  // list of retrieval results
    }

    Returns Server-Sent Events in ChatUI format:

    event: data
    data: "response chunk"

    event: sources
    data: {"sources": [...]}

    event: end
    """
    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])

        async def event_stream():
            async for event in generate_streaming(query, context, chatui_format=True):
                event_type = event["event"]
                event_data = event["data"]

                if event_type == "data":
                    yield f"event: data\ndata: {json.dumps(event_data)}\n\n"
                elif event_type == "sources":
                    yield f"event: sources\ndata: {json.dumps(event_data)}\n\n"
                elif event_type == "end":
                    yield "event: end\ndata: {}\n\n"
                elif event_type == "error":
                    yield f"event: error\ndata: {json.dumps(event_data)}\n\n"

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
            }
        )
    except Exception as e:
        logger.exception("Streaming endpoint failed")
        # Capture the message now: `e` is cleared once the except block
        # exits, but error_stream() is only consumed later by the response.
        error_message = str(e)

        async def error_stream():
            yield f"event: error\ndata: {json.dumps({'error': error_message})}\n\n"

        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream"
        )
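# ---------------------------------------------------------------------
# Example client calls (illustrative sketch only, not executed by this
# module; the httpx dependency, local URL, and sample payload values are
# assumptions — the paths and JSON shape mirror the two handlers above):
#
#   import httpx
#
#   # Non-streaming: POST a query plus retrieval results to /generate
#   resp = httpx.post(
#       "http://localhost:7860/generate",
#       json={"query": "What do the documents say?", "context": []},
#   )
#   print(resp.json())  # {"answer": "... [1][2]", "sources": [...]}
#
#   # Streaming: read the raw SSE lines emitted by /generate/stream
#   with httpx.stream(
#       "POST",
#       "http://localhost:7860/generate/stream",
#       json={"query": "What do the documents say?", "context": []},
#   ) as resp:
#       for line in resp.iter_lines():
#           if line:
#               print(line)  # alternating "event: ..." / "data: ..." lines
# ---------------------------------------------------------------------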
# ---------------------------------------------------------------------
# Wrapper function to handle async streaming for Gradio
# ---------------------------------------------------------------------
def generate_streaming_wrapper(query: str, context: str):
    """Wrapper to convert an async generator to a sync generator for Gradio."""
    logger.info(f"Starting generation request - Query length: {len(query)}, Context length: {len(context)}")

    async def _async_generator():
        async for chunk in generate_streaming(query, context, chatui_format=False):
            yield chunk

    # Reuse this thread's event loop if it has one; otherwise create one
    try:
        loop = asyncio.get_event_loop()
        logger.debug("Using existing event loop")
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        logger.debug("Created new event loop")

    # Drive the async generator synchronously, one chunk at a time
    async_gen = _async_generator()

    # Accumulate chunks so Gradio re-renders the full answer on each yield
    accumulated_text = ""
    chunk_count = 0
    while True:
        try:
            chunk = loop.run_until_complete(async_gen.__anext__())
            accumulated_text += chunk
            chunk_count += 1
            yield accumulated_text  # yield the accumulated text, not just the chunk
        except StopAsyncIteration:
            logger.info(f"Generation completed - Total chunks: {chunk_count}, Final text length: {len(accumulated_text)}")
            break

# ---------------------------------------------------------------------
# Gradio Interface with MCP support and streaming
# ---------------------------------------------------------------------
logger.info("Initializing Gradio interface")
ui = gr.Interface(
    fn=generate_streaming_wrapper,  # use the streaming wrapper defined above
    inputs=[
        gr.Textbox(
            label="Query",
            lines=2,
            placeholder="Enter query here",
            info="The query to search for in the vector database"
        ),
        gr.Textbox(
            label="Context",
            lines=8,
            placeholder="Paste relevant context here",
            info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI should accept anything"
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True
    ),
    title="ChatFed Generation Module",
    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (e.g. context supplied by a semantic retriever service).",
    api_name="generate"
)

# Mount the Gradio app onto the FastAPI app
app = gr.mount_gradio_app(app, ui, path="/gradio")

# Run the combined FastAPI + Gradio app with uvicorn
if __name__ == "__main__":
    import uvicorn
    logger.info("Starting ChatFed Generation Module server")
    logger.info("FastAPI server will be available at http://0.0.0.0:7860")
    logger.info("Gradio UI will be available at http://0.0.0.0:7860/gradio")
    logger.info("ChatUI endpoints: /generate (non-streaming), /generate/stream (streaming)")
    uvicorn.run(app, host="0.0.0.0", port=7860)
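# ---------------------------------------------------------------------
# Programmatic UI-endpoint sketch (assumptions: a compatible
# gradio_client version is installed and the server runs locally on
# port 7860; the sample inputs are placeholders). Because the Interface
# sets api_name="generate", gradio_client can call it directly:
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860/gradio")
#   answer = client.predict(
#       "What do the documents say?",      # Query textbox
#       "pasted retrieval results here",   # Context textbox
#       api_name="/generate",
#   )
#   print(answer)  # final accumulated text from the streaming wrapper
# ---------------------------------------------------------------------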