File size: 6,863 Bytes
e3fc855
bd5c958
 
588855b
0c8a459
e3fc855
1c4883e
0c8a459
e3fc855
8be0409
aad57ad
e3fc855
 
 
 
 
 
 
bd5c958
e3fc855
bd5c958
 
e3fc855
 
 
 
 
be51c46
 
 
dec4548
8be0409
be51c46
e3fc855
0c8a459
e3fc855
 
0c8a459
e3fc855
 
 
0c8a459
e3fc855
 
 
 
 
0c8a459
e3fc855
 
 
0c8a459
 
 
 
 
 
 
 
bd5c958
0c8a459
 
 
 
330e345
a413d29
330e345
bd5c958
 
 
 
 
 
 
 
 
e3fc855
588855b
 
 
 
 
 
 
 
 
a413d29
 
0c8a459
bd5c958
 
0c8a459
aad57ad
 
 
0c8a459
 
15edc9e
d613a3b
bd5c958
 
 
 
 
 
 
 
 
 
0c8a459
 
bd5c958
0c8a459
 
bd5c958
0c8a459
8be0409
 
 
 
 
 
 
 
 
0c8a459
 
 
 
 
bd5c958
a413d29
0c8a459
262f775
 
0c8a459
aad57ad
 
 
1766ca9
 
 
 
1a86629
 
 
aad57ad
0c8a459
 
bd5c958
0c8a459
bd5c958
 
 
0c8a459
 
bd5c958
e3fc855
 
 
0c8a459
e3fc855
 
 
0c8a459
 
 
 
 
 
 
 
 
e3fc855
 
 
0c8a459
e3fc855
 
 
 
 
cc2d90e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import asyncio
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict
import uvicorn
import json
import time

# Initialiser l'application FastAPI
app = FastAPI()

# Configurer CORS pour autoriser toutes les origines
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Monter un répertoire statique pour servir le fichier index.html
app.mount("/static", StaticFiles(directory="static"), name="static")

class MockRequest(BaseModel):
    """Définit la structure attendue pour le corps de la requête POST."""
    parameter: str
    model: str = None
    secret: str  

class ConnectionManager:
    """Gère les connexions WebSocket actives."""
    def __init__(self):
        self.active_connections: List[WebSocket] = []
        # Dictionnaire pour attendre les réponses des clients
        self.response_futures: Dict[str, asyncio.Future] = {}

    async def connect(self, websocket: WebSocket):
        """Accepte une nouvelle connexion WebSocket."""
        await websocket.accept()
        self.active_connections.append(websocket)
        print(f"Nouvelle connexion WebSocket. Total: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Ferme une connexion WebSocket."""
        self.active_connections.remove(websocket)
        print(f"Déconnexion WebSocket. Total: {len(self.active_connections)}")

    async def broadcast(self, message: str):
        """Envoie un message à tous les clients connectés."""
        # Pour ce cas simple, nous n'envoyons qu'au premier client connecté
        if self.active_connections:
            websocket = self.active_connections[0]
            await websocket.send_text(message)
            # Créer un Future pour attendre la réponse
            future = asyncio.get_event_loop().create_future()
            # Utilise l'identifiant du client comme clé
            client_id = str(id(websocket))
            self.response_futures[client_id] = future
            return future
        return None

manager = ConnectionManager()

def verify_secret(provided_secret: str) -> bool:
    """Vérifie si le secret fourni correspond à celui de la variable d'environnement."""
    expected_secret = os.getenv("API_SECRET")
    
    if not expected_secret:
        print("ATTENTION: Variable d'environnement API_SECRET non définie!")
        return False
    
    return provided_secret == expected_secret

@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML page."""
    try:
        with open("static/index.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="index.html not found")

@app.post("/v1/mock")
async def mock_endpoint(payload: MockRequest):
    """
    Endpoint API qui prend un string et un secret, vérifie le secret,
    puis transmet via WebSocket, attend une réponse et la retourne.
    """

    start_time = time.monotonic()
    
    try:
        input_string = payload.parameter
        selected_model = payload.model
        provided_secret = payload.secret

        # Vérification du secret AVANT tout traitement
        if not verify_secret(provided_secret):
            print(f"Tentative d'accès avec un secret invalide: '{provided_secret[:10]}...'")
            raise HTTPException(
                status_code=401, 
                detail="Secret invalide. Accès non autorisé."
            )

        print(f"Secret vérifié avec succès. Endpoint /v1/mock appelé avec: '{input_string}'")

        if input_string is None:
            raise HTTPException(status_code=400, detail="Le paramètre 'parameter' est manquant.")
            
        if not manager.active_connections:
            raise HTTPException(status_code=503, detail="Aucun client WebSocket n'est connecté.")

        # Créer un dictionnaire avec les données à envoyer
        message_data = {
            "prompt": input_string,
            "model": selected_model
        }
        
        # Envoyer le message via WebSocket (sérialiser en JSON)
        response_future = await manager.broadcast(json.dumps(message_data))
        
        # Envoyer le message via WebSocket et obtenir un "future" pour la réponse
        print("Envoi du message au client WebSocket...")
        response_future = await manager.broadcast(input_string)

        if response_future is None:
            raise HTTPException(status_code=500, detail="Échec de la diffusion du message.")

        try:
            # Attendre la réponse du client WebSocket avec un timeout de 60 secondes
            websocket_response = await asyncio.wait_for(response_future, timeout=60.0)
            print(f"Réponse reçue du WebSocket: '{websocket_response}'")
            end_time = time.monotonic()
            duration = end_time - start_time
            print(f"Requête complétée en {duration:.2f} secondes.")
            return {
                "response_from_client": websocket_response,
                "completion_time_in_seconds": round(duration, 2)
            }




        except asyncio.TimeoutError:
            print("Timeout: Aucune réponse du client WebSocket.")
            raise HTTPException(status_code=408, detail="Timeout: Le client n'a pas répondu à temps.")

    except HTTPException:
        # Re-lever les HTTPException sans les wrapper
        raise
    except Exception as e:
        print(f"Erreur dans /v1/mock: {e}")
        raise HTTPException(status_code=500, detail=f"Une erreur interne est survenue: {str(e)}")

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Gère la communication WebSocket avec le client."""
    await manager.connect(websocket)
    try:
        while True:
            # Attendre un message du client
            data = await websocket.receive_text()
            print(f"Message reçu du client: '{data}'")

            # Trouver le "future" correspondant et y mettre le résultat
            client_id = str(id(websocket))
            if client_id in manager.response_futures:
                manager.response_futures[client_id].set_result(data)
                del manager.response_futures[client_id] # Nettoyer après utilisation

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        print("Client déconnecté.")
    except Exception as e:
        print(f"Erreur dans le WebSocket: {e}")
        manager.disconnect(websocket)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)