File size: 6,662 Bytes
79ad53d
cfa142b
caa8809
f852f01
 
 
 
79ad53d
caa8809
 
 
 
 
 
 
 
 
 
 
f852f01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfa142b
 
 
 
 
caa8809
 
cfa142b
f852f01
cfa142b
 
 
 
 
caa8809
cfa142b
 
 
caa8809
cfa142b
 
 
 
7d8975d
 
caa8809
7d8975d
cfa142b
 
 
7d8975d
caa8809
7d8975d
cfa142b
caa8809
cfa142b
 
287959e
ec4377c
287959e
caa8809
f346328
cfa142b
f346328
 
 
 
 
 
 
 
 
 
 
 
 
 
ec4377c
 
 
cfa142b
ec4377c
 
 
 
f346328
 
f852f01
 
 
f346328
8281cfa
f852f01
caa8809
f852f01
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import asyncio
import json
import logging

import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from utils.generator import generate, generate_streaming

# Root logger setup: INFO level, every record mirrored to the console
# and appended to app.log in the working directory.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler(), logging.FileHandler('app.log')],
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------
# FastAPI app for ChatUI endpoints
# ---------------------------------------------------------------------
# The Gradio UI is mounted onto this same app further down at /gradio,
# so one uvicorn process serves both the JSON API and the UI.
app = FastAPI(title="ChatFed Generator", version="1.0.0")

@app.post("/generate")
async def generate_endpoint(request: Request):
    """
    Non-streaming generation endpoint for ChatUI format.
    
    Expected request body:
    {
        "query": "user question",
        "context": [...] // list of retrieval results
    }
    
    Returns ChatUI format:
    {
        "answer": "response with citations [1][2]",
        "sources": [{"link": "doc://...", "title": "..."}]
    }
    
    On failure, responds with HTTP 500 and {"error": "..."} so callers
    can distinguish errors from successful generations.
    """
    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])
        
        return await generate(query, context, chatui_format=True)
        
    except Exception as e:
        logger.exception("Generation endpoint failed")
        # Previously the error dict was returned with a 200 status, which
        # made failures indistinguishable from successes for HTTP clients.
        return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/generate/stream")
async def generate_stream_endpoint(request: Request):
    """
    Streaming generation endpoint for ChatUI format.
    
    Expected request body:
    {
        "query": "user question", 
        "context": [...] // list of retrieval results
    }
    
    Returns Server-Sent Events in ChatUI format:
    event: data
    data: "response chunk"
    
    event: sources  
    data: {"sources": [...]}
    
    event: end
    """
    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])
        
        async def event_stream():
            # Translate generator events into SSE frames. "end" carries an
            # empty payload; data/sources/error forward their JSON payload.
            # Any other event type is dropped (the generator only emits
            # these four — TODO confirm against utils.generator).
            async for event in generate_streaming(query, context, chatui_format=True):
                event_type = event["event"]
                event_data = event["data"]
                
                if event_type == "end":
                    yield "event: end\ndata: {}\n\n"
                elif event_type in ("data", "sources", "error"):
                    yield f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
        
        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
            }
        )
        
    except Exception as e:
        logger.exception("Streaming endpoint failed")
        # Capture the message NOW: Python unbinds `e` when the except
        # block exits, but the generator body only runs later while the
        # response streams. The old closure over `e` raised NameError at
        # stream time instead of delivering the error event.
        error_message = str(e)
        
        async def error_stream():
            yield f"event: error\ndata: {json.dumps({'error': error_message})}\n\n"
        
        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream"
        )

# ---------------------------------------------------------------------
# Wrapper function to handle async streaming for Gradio
# ---------------------------------------------------------------------
def generate_streaming_wrapper(query: str, context: str):
    """Convert the async generation stream into a sync generator for Gradio.

    Gradio streams output by iterating a plain generator, so each chunk of
    the async generator is pumped through a dedicated event loop. Yields
    the ACCUMULATED text after each chunk (not the chunk itself) because
    Gradio replaces the output textbox contents on every yield.
    """
    logger.info(f"Starting generation request - Query length: {len(query)}, Context length: {len(context)}")
    
    async_gen = generate_streaming(query, context, chatui_format=False)
    
    # Always use a fresh, dedicated event loop: asyncio.get_event_loop()
    # is deprecated outside a running loop and raises in worker threads
    # (where Gradio may invoke this function) when no loop has been set.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    
    accumulated_text = ""
    chunk_count = 0
    
    try:
        while True:
            try:
                chunk = loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                logger.info(f"Generation completed - Total chunks: {chunk_count}, Final text length: {len(accumulated_text)}")
                break
            accumulated_text += chunk
            chunk_count += 1
            yield accumulated_text  # Yield the accumulated text, not just the chunk
    finally:
        # Close the async generator and the loop even if the consumer
        # abandons the stream mid-way; otherwise both leak per request.
        loop.run_until_complete(async_gen.aclose())
        loop.close()
        asyncio.set_event_loop(None)

# ---------------------------------------------------------------------
# Gradio Interface with MCP support and streaming
# ---------------------------------------------------------------------
logger.info("Initializing Gradio interface")
ui = gr.Interface(
    fn=generate_streaming_wrapper,  # sync wrapper that yields accumulated text per chunk
    inputs=[
        gr.Textbox(
            label="Query", 
            lines=2, 
            placeholder="Enter query here",
            info="The query to search for in the vector database"
        ),
        gr.Textbox(
            label="Context", 
            lines=8, 
            placeholder="Paste relevant context here",
            # Typo fixed: "except" -> "accept" in the user-facing help text.
            info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI should accept anything"
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Answer",
        lines=6,
        show_copy_button=True
    ),
    title="ChatFed Generation Module",
    description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by semantic retriever service).",
    api_name="generate"
)

# Mount Gradio app to FastAPI so both are served by one uvicorn process.
app = gr.mount_gradio_app(app, ui, path="/gradio")

# Launch with MCP server enabled
if __name__ == "__main__":
    import uvicorn
    
    # Announce where each surface of the service will be reachable.
    startup_messages = (
        "Starting ChatFed Generation Module server",
        "FastAPI server will be available at http://0.0.0.0:7860",
        "Gradio UI will be available at http://0.0.0.0:7860/gradio",
        "ChatUI endpoints: /generate (non-streaming), /generate/stream (streaming)",
    )
    for message in startup_messages:
        logger.info(message)
    
    uvicorn.run(app, host="0.0.0.0", port=7860)