Spaces:
Sleeping
Sleeping
| import asyncio | |
| import os | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Dict | |
| import uvicorn | |
| import json | |
| import time | |
| # Initialiser l'application FastAPI | |
| app = FastAPI() | |
| # Configurer CORS pour autoriser toutes les origines | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Monter un répertoire statique pour servir le fichier index.html | |
| app.mount("/static", StaticFiles(directory="static"), name="static") | |
| class MockRequest(BaseModel): | |
| """Définit la structure attendue pour le corps de la requête POST.""" | |
| parameter: str | |
| model: str = None | |
| secret: str | |
| class ConnectionManager: | |
| """Gère les connexions WebSocket actives.""" | |
| def __init__(self): | |
| self.active_connections: List[WebSocket] = [] | |
| # Dictionnaire pour attendre les réponses des clients | |
| self.response_futures: Dict[str, asyncio.Future] = {} | |
| async def connect(self, websocket: WebSocket): | |
| """Accepte une nouvelle connexion WebSocket.""" | |
| await websocket.accept() | |
| self.active_connections.append(websocket) | |
| print(f"Nouvelle connexion WebSocket. Total: {len(self.active_connections)}") | |
| def disconnect(self, websocket: WebSocket): | |
| """Ferme une connexion WebSocket.""" | |
| self.active_connections.remove(websocket) | |
| print(f"Déconnexion WebSocket. Total: {len(self.active_connections)}") | |
| async def broadcast(self, message: str): | |
| """Envoie un message à tous les clients connectés.""" | |
| # Pour ce cas simple, nous n'envoyons qu'au premier client connecté | |
| if self.active_connections: | |
| websocket = self.active_connections[0] | |
| await websocket.send_text(message) | |
| # Créer un Future pour attendre la réponse | |
| future = asyncio.get_event_loop().create_future() | |
| # Utilise l'identifiant du client comme clé | |
| client_id = str(id(websocket)) | |
| self.response_futures[client_id] = future | |
| return future | |
| return None | |
| manager = ConnectionManager() | |
| def verify_secret(provided_secret: str) -> bool: | |
| """Vérifie si le secret fourni correspond à celui de la variable d'environnement.""" | |
| expected_secret = os.getenv("API_SECRET") | |
| if not expected_secret: | |
| print("ATTENTION: Variable d'environnement API_SECRET non définie!") | |
| return False | |
| return provided_secret == expected_secret | |
| async def root(): | |
| """Serve the main HTML page.""" | |
| try: | |
| with open("static/index.html", "r", encoding="utf-8") as f: | |
| return HTMLResponse(content=f.read()) | |
| except FileNotFoundError: | |
| raise HTTPException(status_code=404, detail="index.html not found") | |
| async def mock_endpoint(payload: MockRequest): | |
| """ | |
| Endpoint API qui prend un string et un secret, vérifie le secret, | |
| puis transmet via WebSocket, attend une réponse et la retourne. | |
| """ | |
| start_time = time.monotonic() | |
| try: | |
| input_string = payload.parameter | |
| selected_model = payload.model | |
| provided_secret = payload.secret | |
| # Vérification du secret AVANT tout traitement | |
| if not verify_secret(provided_secret): | |
| print(f"Tentative d'accès avec un secret invalide: '{provided_secret[:10]}...'") | |
| raise HTTPException( | |
| status_code=401, | |
| detail="Secret invalide. Accès non autorisé." | |
| ) | |
| print(f"Secret vérifié avec succès. Endpoint /v1/mock appelé avec: '{input_string}'") | |
| if input_string is None: | |
| raise HTTPException(status_code=400, detail="Le paramètre 'parameter' est manquant.") | |
| if not manager.active_connections: | |
| raise HTTPException(status_code=503, detail="Aucun client WebSocket n'est connecté.") | |
| # Créer un dictionnaire avec les données à envoyer | |
| message_data = { | |
| "prompt": input_string, | |
| "model": selected_model | |
| } | |
| # Envoyer le message via WebSocket (sérialiser en JSON) | |
| response_future = await manager.broadcast(json.dumps(message_data)) | |
| # Envoyer le message via WebSocket et obtenir un "future" pour la réponse | |
| print("Envoi du message au client WebSocket...") | |
| response_future = await manager.broadcast(input_string) | |
| if response_future is None: | |
| raise HTTPException(status_code=500, detail="Échec de la diffusion du message.") | |
| try: | |
| # Attendre la réponse du client WebSocket avec un timeout de 60 secondes | |
| websocket_response = await asyncio.wait_for(response_future, timeout=60.0) | |
| print(f"Réponse reçue du WebSocket: '{websocket_response}'") | |
| end_time = time.monotonic() | |
| duration = end_time - start_time | |
| print(f"Requête complétée en {duration:.2f} secondes.") | |
| return { | |
| "response_from_client": websocket_response, | |
| "completion_time_in_seconds": round(duration, 2) | |
| } | |
| except asyncio.TimeoutError: | |
| print("Timeout: Aucune réponse du client WebSocket.") | |
| raise HTTPException(status_code=408, detail="Timeout: Le client n'a pas répondu à temps.") | |
| except HTTPException: | |
| # Re-lever les HTTPException sans les wrapper | |
| raise | |
| except Exception as e: | |
| print(f"Erreur dans /v1/mock: {e}") | |
| raise HTTPException(status_code=500, detail=f"Une erreur interne est survenue: {str(e)}") | |
| async def websocket_endpoint(websocket: WebSocket): | |
| """Gère la communication WebSocket avec le client.""" | |
| await manager.connect(websocket) | |
| try: | |
| while True: | |
| # Attendre un message du client | |
| data = await websocket.receive_text() | |
| print(f"Message reçu du client: '{data}'") | |
| # Trouver le "future" correspondant et y mettre le résultat | |
| client_id = str(id(websocket)) | |
| if client_id in manager.response_futures: | |
| manager.response_futures[client_id].set_result(data) | |
| del manager.response_futures[client_id] # Nettoyer après utilisation | |
| except WebSocketDisconnect: | |
| manager.disconnect(websocket) | |
| print("Client déconnecté.") | |
| except Exception as e: | |
| print(f"Erreur dans le WebSocket: {e}") | |
| manager.disconnect(websocket) | |
| if __name__ == "__main__": | |
| uvicorn.run(app, host="0.0.0.0", port=7860) |