# NOTE(review): removed non-code residue (Hugging Face Spaces page header,
# commit hashes, and a line-number gutter left over from a web-page copy/paste).
import gradio as gr
import asyncio
import logging
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from utils.generator import generate_streaming, generate
# Configure logging: emit to both stderr and a rotating-free local file.
# NOTE(review): FileHandler('app.log') writes to the process CWD — confirm the
# container/Space filesystem is writable there.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),        # console output
        logging.FileHandler('app.log')  # persistent log file
    ]
)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# FastAPI app for ChatUI endpoints
# ---------------------------------------------------------------------
# The Gradio UI is mounted onto this same app later (under /gradio), so a
# single server exposes both the JSON/SSE API and the interactive UI.
app = FastAPI(title="ChatFed Generator", version="1.0.0")
@app.post("/generate")
async def generate_endpoint(request: Request):
    """
    Non-streaming generation endpoint in ChatUI format.

    Expected request body::

        {
            "query": "user question",
            "context": [...]   # list of retrieval results
        }

    Returns ChatUI format::

        {
            "answer": "response with citations [1][2]",
            "sources": [{"link": "doc://...", "title": "..."}]
        }

    On failure the error is logged and returned as ``{"error": "..."}``.
    """
    try:
        payload = await request.json()
        return await generate(
            payload.get("query", ""),
            payload.get("context", []),
            chatui_format=True,
        )
    except Exception as exc:
        logger.exception("Generation endpoint failed")
        return {"error": str(exc)}
@app.post("/generate/stream")
async def generate_stream_endpoint(request: Request):
    """
    Streaming generation endpoint for ChatUI format.

    Expected request body::

        {
            "query": "user question",
            "context": [...]   # list of retrieval results
        }

    Returns Server-Sent Events in ChatUI format::

        event: data
        data: "response chunk"

        event: sources
        data: {"sources": [...]}

        event: end
    """
    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])

        async def event_stream():
            # Translate each generator event into one SSE frame.
            async for event in generate_streaming(query, context, chatui_format=True):
                event_type = event["event"]
                event_data = event["data"]
                if event_type == "data":
                    yield f"event: data\ndata: {json.dumps(event_data)}\n\n"
                elif event_type == "sources":
                    yield f"event: sources\ndata: {json.dumps(event_data)}\n\n"
                elif event_type == "end":
                    yield "event: end\ndata: {}\n\n"
                elif event_type == "error":
                    yield f"event: error\ndata: {json.dumps(event_data)}\n\n"

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
            }
        )
    except Exception as e:
        logger.exception("Streaming endpoint failed")
        # BUG FIX: the `except ... as e` binding is deleted when this block
        # exits (PEP 3110), but error_stream() only executes later, while the
        # response is being streamed — referencing `e` there raised NameError.
        # Serialize the payload eagerly so the closure captures a plain str.
        error_payload = json.dumps({"error": str(e)})

        async def error_stream():
            yield f"event: error\ndata: {error_payload}\n\n"

        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream"
        )
# ---------------------------------------------------------------------
# Wrapper function to handle async streaming for Gradio
# ---------------------------------------------------------------------
def generate_streaming_wrapper(query: str, context: str):
    """Bridge the async token stream into a sync generator for Gradio.

    Yields the *accumulated* answer text after every chunk, because Gradio
    streaming Textbox outputs expect the full text so far, not deltas.

    Args:
        query: The user question.
        context: The supporting context/documents (free-form text from the UI).

    Yields:
        str: The answer accumulated up to and including the latest chunk.
    """
    logger.info(
        "Starting generation request - Query length: %d, Context length: %d",
        len(query), len(context),
    )
    # BUG FIX: the previous implementation used asyncio.get_event_loop(),
    # which is deprecated and may return a loop that is already running in
    # this thread (run_until_complete would then raise), and it never closed
    # the loop it created. Always drive the stream on a private loop and
    # clean it up in `finally` — which also runs when Gradio abandons this
    # generator mid-stream (GeneratorExit).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    async_gen = None
    try:
        async_gen = generate_streaming(query, context, chatui_format=False)
        accumulated_text = ""
        chunk_count = 0
        while True:
            try:
                chunk = loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
            accumulated_text += chunk
            chunk_count += 1
            yield accumulated_text  # full text so far, not just the new chunk
        logger.info(
            "Generation completed - Total chunks: %d, Final text length: %d",
            chunk_count, len(accumulated_text),
        )
    finally:
        # Best-effort shutdown of the async generator before closing the loop.
        if async_gen is not None:
            try:
                loop.run_until_complete(async_gen.aclose())
            except Exception:
                logger.debug("Async generator close failed", exc_info=True)
        asyncio.set_event_loop(None)
        loop.close()
# ---------------------------------------------------------------------
# Gradio Interface with MCP support and streaming
# ---------------------------------------------------------------------
logger.info("Initializing Gradio interface")

# Simple two-input form; the wrapper streams partial answers back into the
# output textbox as they are generated.
ui = gr.Interface(
    fn=generate_streaming_wrapper,  # Use streaming wrapper function
    inputs=[
        gr.Textbox(
            label="Query",
            lines=2,
            placeholder="Enter query here",
            info="The query to search for in the vector database"
        ),
        gr.Textbox(
            label="Context",
            lines=8,
            placeholder="Paste relevant context here",
            # FIX: "except" -> "accept" (user-visible typo).
            info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI should accept anything"
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True
    ),
    title="ChatFed Generation Module",
    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by semantic retriever service).",
    api_name="generate"
)
# Mount Gradio app to FastAPI: one server/port serves both the JSON/SSE
# endpoints (at the root) and the interactive UI (under /gradio).
app = gr.mount_gradio_app(app, ui, path="/gradio")
# Launch with MCP server enabled
if __name__ == "__main__":
    # Imported lazily so merely importing this module (e.g. by a WSGI/ASGI
    # runner that supplies its own server) doesn't require uvicorn.
    import uvicorn
    logger.info("Starting ChatFed Generation Module server")
    logger.info("FastAPI server will be available at http://0.0.0.0:7860")
    logger.info("Gradio UI will be available at http://0.0.0.0:7860/gradio")
    logger.info("ChatUI endpoints: /generate (non-streaming), /generate/stream (streaming)")
    uvicorn.run(app, host="0.0.0.0", port=7860)