Spaces:
Sleeping
Sleeping
| import traceback | |
| from fastapi import FastAPI, WebSocket | |
| from fastapi.responses import FileResponse | |
| import asyncio | |
| from fastapi.staticfiles import StaticFiles | |
| from contextlib import asynccontextmanager | |
| import json | |
| from fastapi import HTTPException | |
| from pydantic import BaseModel | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import List, Optional, Any, Dict | |
| from mcp_client import MCPClient | |
# Single shared MCP client for the whole app: connected/closed by `lifespan` and
# used (with lazy reconnect when `mcp.session` is falsy) by the tool endpoints.
mcp = MCPClient()
class ChatMessage(BaseModel):
    # One chat turn: the speaker `role` (this file uses "user" and "assistant")
    # and its text `content`.
    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    # Request body shaped like an OpenAI-style chat-completions call. In the visible
    # source it is only referenced by the commented-out `chat_completions` endpoint.
    model: str = "gemini-2.5-pro-exp-03-25"
    messages: List[ChatMessage]
    # The mutable `[]` default is safe here: pydantic copies field defaults per instance.
    # Presumably a list of tool definitions forwarded as-is — TODO confirm the schema.
    tools: Optional[list] = []
    # None means "no explicit limit" is forwarded.
    max_tokens: Optional[int] = None
class ChatCompletionResponseChoice(BaseModel):
    # One candidate answer inside a ChatCompletionResponse; the defaults describe
    # a single, normally finished choice.
    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    # OpenAI-style chat-completion response envelope. In the visible source it is
    # only built by the commented-out `chat_completions` endpoint, which fills `id`
    # with the request id and `created` with a Unix timestamp.
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect the shared MCP client at startup and close it at shutdown.

    Intended as the FastAPI ``lifespan`` handler. Decorated with
    ``asynccontextmanager`` as FastAPI expects; a bare async generator is only
    accepted by Starlette through a deprecated compatibility wrapper.

    A failed startup connection is only reported: the application still starts.
    """
    try:
        await mcp.connect()
        print("Connexion au MCP réussi !")
    except Exception as e:
        # Best effort: the server must come up even if the MCP server is unreachable.
        print("Warning ! : Connexion au MCP impossible\n", str(e))
    try:
        yield
    finally:
        # `finally` so the MCP resources are released even when an exception is
        # thrown into the generator at the `yield` instead of a clean shutdown.
        if mcp.session:
            try:
                await mcp.exit_stack.aclose()
                print("MCP déconnecté !")
            except Exception as e:
                print("Erreur à la fermeture du MCP\n", str(e))
# The ASGI application; `lifespan` manages the MCP connection around serving.
app = FastAPI(lifespan=lifespan)

# Files under the local ./static directory are served at the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Fully permissive CORS policy.
# NOTE(review): wildcard origins combined with allow_credentials=True means any site
# may make credentialed cross-origin requests (Starlette reflects the request Origin
# in that case). Confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class ConnectionManager:
    """Track open WebSocket connections and relay text messages between them.

    Every accepted socket is registered with an optional "source" label. Messages
    are routed by comparing a destination string with those labels.
    ``response_queues`` maps a request id to the one-slot queue a
    ``wait_for_response`` call is currently blocked on.
    NOTE(review): nothing in the visible source puts values on those queues —
    confirm where the producer side lives.
    """

    def __init__(self):
        # websocket -> source label (None until set_source is called for it)
        self.active_connections: Dict["WebSocket", Optional[str]] = {}
        # request_id -> single-slot queue awaited by wait_for_response
        self.response_queues: Dict[Any, "asyncio.Queue"] = {}

    async def connect(self, websocket: "WebSocket"):
        """Accept the handshake and register the socket with no source label yet."""
        await websocket.accept()
        self.active_connections[websocket] = None

    def set_source(self, websocket: "WebSocket", source: str):
        """Label an already-registered socket with ``source``; unknown sockets are ignored."""
        if websocket in self.active_connections:
            self.active_connections[websocket] = source

    async def send_to_dest(self, destination: str, message: str):
        """Send ``message`` to every registered socket whose source equals ``destination``.

        The targets are taken from a snapshot, because each ``send_text`` awaits and
        other coroutines may connect/remove sockets meanwhile (iterating the live
        dict would raise "dictionary changed size during iteration"). A target
        whose send fails is treated as dead and unregistered, so one broken peer
        neither blocks delivery to the others nor raises into the sender.
        """
        targets = [ws for ws, src in list(self.active_connections.items()) if src == destination]
        for ws in targets:
            try:
                await ws.send_text(message)
            except Exception as e:
                print(f"Envoi impossible vers '{destination}', connexion retirée\n", str(e))
                self.remove(ws)

    def remove(self, websocket: "WebSocket"):
        """Unregister ``websocket``; a no-op if it is not registered."""
        self.active_connections.pop(websocket, None)

    async def wait_for_response(self, request_id: str, timeout: float = 30):
        """Wait up to ``timeout`` seconds for a value put on this request's queue.

        A fresh one-slot queue is registered under ``request_id`` for the duration
        of the wait and unregistered afterwards — but only if it is still the
        registered one, so a later waiter reusing the same id is not clobbered.

        Raises:
            asyncio.TimeoutError: if nothing arrives within ``timeout`` seconds.
        """
        queue = asyncio.Queue(maxsize=1)
        self.response_queues[request_id] = queue
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        finally:
            if self.response_queues.get(request_id) is queue:
                del self.response_queues[request_id]
# Module-level registry of WebSocket connections shared by websocket_endpoint.
manager = ConnectionManager()
async def index_page():
    """Serve the front-end page ``index.html`` (resolved relative to the working directory).

    NOTE(review): no route decorator is visible in this source, so as shown this
    handler is never registered on ``app`` — confirm the intended route.
    """
    page_path = "index.html"
    return FileResponse(page_path)
| # @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) | |
| # async def chat_completions(request: ChatCompletionRequest): | |
| # request_id = str(uuid.uuid4()) | |
| # proxy_ws = next((ws for ws, src in manager.active_connections.items() if src == "proxy"), None) | |
| # if not proxy_ws: | |
| # raise HTTPException(503, "Proxy client not connected !") | |
| # user_msg = next((m for m in request.messages if m.role == "user"), None) | |
| # if not user_msg: | |
| # raise HTTPException(400, "No user message found !") | |
| # proxy_msg = { | |
| # "request_id": request_id, | |
| # "content": user_msg.content, | |
| # "source": "api", | |
| # "destination": "proxy", | |
| # "model": request.model, | |
| # "tools": request.tools, | |
| # "max_tokens": request.max_tokens | |
| # } | |
| # await proxy_ws.send_text(json.dumps(proxy_msg)) | |
| # try: | |
| # response_content = await manager.wait_for_response(request_id) | |
| # except asyncio.TimeoutError: | |
| # raise HTTPException(504, "Proxy response timeout") | |
| # return ChatCompletionResponse( | |
| # id=request_id, | |
| # created=int(time.time()), | |
| # model=request.model, | |
| # choices=[ChatCompletionResponseChoice( | |
| # message=ChatMessage(role="assistant", content=response_content) | |
| # )] | |
| # ) | |
class ToolCallRequest(BaseModel):
    # Request body for call_tools. As read there, each item is expected to look like
    # {"function": {"name": ..., "arguments": ...}} (presumably OpenAI-style tool calls);
    # this model itself only checks that every item is a dict.
    tool_calls: List[Dict[str, Any]]
async def list_tools():
    """Return the tools exposed by the MCP server.

    Reconnects first when ``mcp.session`` is falsy (e.g. the startup connection failed).

    Raises:
        HTTPException: 503 if the MCP connection cannot be established,
            500 if listing the tools fails.

    NOTE(review): no route decorator is visible in this source, so as shown this
    handler is never registered on ``app`` — confirm the intended route.
    """
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            # Chain the cause so server-side tracebacks show the real connection error.
            raise HTTPException(status_code=503, detail=f"Connexion au MCP impossible !\n{str(e)}") from e
    try:
        return await mcp.list_tools()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur lors de la récupération des outils: {str(e)}") from e
async def call_tools(request: ToolCallRequest):
    """Execute each requested tool call on the MCP server and return the results.

    Reconnects first when ``mcp.session`` is falsy. Each entry of
    ``request.tool_calls`` must provide ``function.name`` and may provide
    ``function.arguments``. Arguments can be a JSON-encoded string (empty means
    "no arguments") or an already-decoded mapping.

    Each result becomes ``{"role": "user", "content": text}``, where ``text`` joins
    (with newlines) the ``text`` of every result content part that has one.

    Raises:
        HTTPException: 503 if the MCP connection cannot be established,
            400 if a tool call is malformed, 500 if invoking a tool fails.

    NOTE(review): no route decorator is visible in this source, so as shown this
    handler is never registered on ``app`` — confirm the intended route.
    """
    if not mcp.session:
        try:
            await mcp.connect()
        except Exception as e:
            # A connection failure: report it as such (same wording as list_tools).
            raise HTTPException(status_code=503, detail=f"Connexion au MCP impossible !\n{str(e)}") from e

    result_tools = []
    for tool_call in request.tool_calls:
        print(tool_call)
        try:
            tool = tool_call["function"]
            tool_name = tool["name"]
            tool_args = tool.get("arguments") or {}
            if isinstance(tool_args, str):
                tool_args = json.loads(tool_args) if tool_args.strip() else {}
        except (KeyError, TypeError, AttributeError, ValueError) as e:
            # Malformed client payload (missing keys, wrong types, invalid JSON): a 400, not a 500.
            raise HTTPException(status_code=400, detail=f"Appel d'outil invalide: {str(e)}") from e
        try:
            result = await mcp.session.call_tool(tool_name, tool_args)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erreur lors de l'appel des outils: {str(e)}") from e
        # Keep every textual part rather than only content[0], which failed on an
        # empty result or a non-text first part and silently dropped extra parts.
        texts = [part.text for part in (result.content or []) if getattr(part, "text", None) is not None]
        result_tools.append({
            "role": "user",
            "content": "\n".join(texts),
        })
    print("Finished !")
    return result_tools
async def websocket_endpoint(websocket: WebSocket):
    """Register a WebSocket client and relay its messages to other clients.

    Protocol:
      * The first text frame is JSON. If it is an object with a ``source`` key,
        that value becomes this connection's label in ``manager``.
      * Each subsequent text frame must be a JSON object with a ``destination``
        key. The raw frame is forwarded unchanged to every connection labelled
        with that destination. Malformed frames are skipped rather than closing
        the connection.

    The connection is always unregistered on exit, including on cancellation.
    Closing is best-effort because the peer may already be gone.

    NOTE(review): no route decorator is visible in this source, so as shown this
    handler is never registered on ``app`` — confirm the intended route.
    """
    # Local import: the module header only brings FastAPI and WebSocket into scope.
    from fastapi import WebSocketDisconnect

    await manager.connect(websocket)
    try:
        init_msg = json.loads(await websocket.receive_text())
        if isinstance(init_msg, dict) and 'source' in init_msg:
            manager.set_source(websocket, init_msg['source'])
            print(init_msg['source'])
        while True:
            message = await websocket.receive_text()
            try:
                destination = json.loads(message)["destination"]
            except (ValueError, KeyError, TypeError):
                print("Message ignoré (JSON invalide ou sans 'destination') :", message)
                continue
            await manager.send_to_dest(destination, message)
    except WebSocketDisconnect:
        pass  # normal client disconnect: nothing to report
    except Exception:
        # Unexpected failure: surface it instead of silently dropping the client.
        traceback.print_exc()
    finally:
        manager.remove(websocket)
        try:
            await websocket.close()
        except Exception:
            # Deliberate best-effort close: after a disconnect the socket is already
            # closed and a second close raises; there is nothing useful to do about it.
            pass