Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import gradio as gr | |
| import asyncio | |
| import logging | |
| import json | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import StreamingResponse | |
| from utils.generator import generate_streaming, generate | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| handlers=[ | |
| logging.StreamHandler(), | |
| logging.FileHandler('app.log') | |
| ] | |
| ) | |
| logger = logging.getLogger(__name__) | |
# ---------------------------------------------------------------------
# FastAPI app for ChatUI endpoints
# ---------------------------------------------------------------------
# Root ASGI application; the Gradio UI is mounted onto it under /gradio
# later in this module, and the whole thing is served by uvicorn from
# the __main__ guard at the bottom of the file.
app = FastAPI(title="ChatFed Generator", version="1.0.0")
async def generate_endpoint(request: Request):
    """Non-streaming generation endpoint for the ChatUI format.

    Expected request body::

        {"query": "user question", "context": [...]}  # list of retrieval results

    Returns ChatUI format::

        {"answer": "response with citations [1][2]",
         "sources": [{"link": "doc://...", "title": "..."}]}

    Any failure is logged and reported as ``{"error": "<message>"}``
    rather than propagating an exception to the ASGI layer.
    """
    try:
        payload = await request.json()
        return await generate(
            payload.get("query", ""),
            payload.get("context", []),
            chatui_format=True,
        )
    except Exception as exc:
        logger.exception("Generation endpoint failed")
        return {"error": str(exc)}
async def generate_stream_endpoint(request: Request):
    """Streaming generation endpoint for the ChatUI format.

    Expected request body::

        {"query": "user question", "context": [...]}  # list of retrieval results

    Returns Server-Sent Events in ChatUI format::

        event: data
        data: "response chunk"
        event: sources
        data: {"sources": [...]}
        event: end

    NOTE(review): the except clause below only covers failures that occur
    before the StreamingResponse is returned (e.g. a malformed JSON body);
    errors raised while the stream is being consumed surface as ``error``
    events emitted by the generator itself.
    """
    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])

        async def event_stream():
            # Translate generator events into SSE frames. Unknown event
            # types are silently dropped.
            async for event in generate_streaming(query, context, chatui_format=True):
                kind = event["event"]
                payload = event["data"]
                if kind == "end":
                    yield "event: end\ndata: {}\n\n"
                elif kind in ("data", "sources", "error"):
                    yield f"event: {kind}\ndata: {json.dumps(payload)}\n\n"

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
            },
        )
    except Exception as exc:
        logger.exception("Streaming endpoint failed")

        async def error_stream():
            yield f"event: error\ndata: {json.dumps({'error': str(exc)})}\n\n"

        return StreamingResponse(error_stream(), media_type="text/event-stream")
| # --------------------------------------------------------------------- | |
| # Wrapper function to handle async streaming for Gradio | |
| # --------------------------------------------------------------------- | |
def generate_streaming_wrapper(query: str, context: str):
    """Drive the async ``generate_streaming`` generator from sync Gradio code.

    Gradio calls this in a worker thread, so we own a private event loop:
    ``asyncio.get_event_loop()`` is deprecated outside a running loop and
    raises/warns in non-main threads, so we always create a fresh loop and
    close it when the stream ends (the original leaked one loop per call).

    Args:
        query: The user question.
        context: The context/documents to answer from (free-form text
            from the UI textbox).

    Yields:
        str: The accumulated answer text after each chunk, as Gradio's
        streaming output expects (full text so far, not the delta).
    """
    logger.info(f"Starting generation request - Query length: {len(query)}, Context length: {len(context)}")

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    async_gen = generate_streaming(query, context, chatui_format=False)

    accumulated_text = ""
    chunk_count = 0
    try:
        while True:
            try:
                chunk = loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                logger.info(f"Generation completed - Total chunks: {chunk_count}, Final text length: {len(accumulated_text)}")
                break
            accumulated_text += chunk
            chunk_count += 1
            yield accumulated_text  # Yield the accumulated text, not just the chunk
    finally:
        # Finalize the async generator (no-op if already exhausted) and
        # release the loop's resources even if the consumer bailed early.
        loop.run_until_complete(async_gen.aclose())
        asyncio.set_event_loop(None)
        loop.close()
# ---------------------------------------------------------------------
# Gradio Interface with MCP support and streaming
# ---------------------------------------------------------------------
logger.info("Initializing Gradio interface")

ui = gr.Interface(
    fn=generate_streaming_wrapper,  # Use streaming wrapper function
    inputs=[
        gr.Textbox(
            label="Query",
            lines=2,
            placeholder="Enter query here",
            info="The query to search for in the vector database"
        ),
        gr.Textbox(
            label="Context",
            lines=8,
            placeholder="Paste relevant context here",
            # Fixed user-facing typo: "except" -> "accept".
            info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI should accept anything"
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True
    ),
    title="ChatFed Generation Module",
    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by semantic retriever service).",
    api_name="generate"
)
# Register the ChatUI endpoints on the FastAPI app. Without this, the
# handlers defined above are never exposed, even though the startup log
# advertises /generate and /generate/stream.
app.post("/generate")(generate_endpoint)
app.post("/generate/stream")(generate_stream_endpoint)

# Mount Gradio app to FastAPI (UI served under /gradio).
app = gr.mount_gradio_app(app, ui, path="/gradio")
# Launch with MCP server enabled
if __name__ == "__main__":
    import uvicorn

    # Announce where everything will be reachable, then hand control
    # to uvicorn (blocks until shutdown).
    for startup_message in (
        "Starting ChatFed Generation Module server",
        "FastAPI server will be available at http://0.0.0.0:7860",
        "Gradio UI will be available at http://0.0.0.0:7860/gradio",
        "ChatUI endpoints: /generate (non-streaming), /generate/stream (streaming)",
    ):
        logger.info(startup_message)

    uvicorn.run(app, host="0.0.0.0", port=7860)