Update main.py
Browse files
main.py
CHANGED
|
@@ -3,7 +3,8 @@ import httpx
|
|
| 3 |
import json
|
| 4 |
import time
|
| 5 |
import asyncio
|
| 6 |
-
from fastapi import FastAPI, HTTPException
|
|
|
|
| 7 |
from fastapi.responses import StreamingResponse
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
from typing import List, Dict, Any, Optional, Union, Literal
|
|
@@ -12,11 +13,30 @@ from dotenv import load_dotenv
|
|
| 12 |
# Load environment variables
|
| 13 |
load_dotenv()
|
| 14 |
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
|
|
|
|
|
|
|
| 15 |
if not REPLICATE_API_TOKEN:
|
| 16 |
raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# FastAPI Init
|
| 19 |
-
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# --- Pydantic Models ---
|
| 22 |
class ModelCard(BaseModel):
|
|
@@ -200,7 +220,6 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
|
|
| 200 |
async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
|
| 201 |
current_event = None
|
| 202 |
accumulated_content = ""
|
| 203 |
-
# first_token = True <- REMOVED THIS
|
| 204 |
|
| 205 |
async for line in sse.aiter_lines():
|
| 206 |
if not line: continue
|
|
@@ -216,14 +235,7 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
|
|
| 216 |
except (json.JSONDecodeError, TypeError):
|
| 217 |
content_token = raw_data
|
| 218 |
|
| 219 |
-
#
|
| 220 |
-
# The lstrip() logic has been COMPLETELY REMOVED
|
| 221 |
-
# to send the raw, unmodified token from Replicate.
|
| 222 |
-
#
|
| 223 |
-
# if first_token:
|
| 224 |
-
# content_token = content_token.lstrip()
|
| 225 |
-
# if content_token:
|
| 226 |
-
# first_token = False
|
| 227 |
|
| 228 |
accumulated_content += content_token
|
| 229 |
completion_tokens += 1
|
|
@@ -234,7 +246,6 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
|
|
| 234 |
chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
|
| 235 |
yield f"data: {chunk.json()}\n\n"
|
| 236 |
else:
|
| 237 |
-
# Only yield a chunk if there is content to send.
|
| 238 |
if content_token:
|
| 239 |
chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
|
| 240 |
yield f"data: {chunk.json()}\n\n"
|
|
@@ -253,12 +264,18 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
|
|
| 253 |
yield "data: [DONE]\n\n"
|
| 254 |
|
| 255 |
# --- Endpoints ---
|
| 256 |
-
@app.get("/v1/models")
|
| 257 |
async def list_models():
|
|
|
|
|
|
|
|
|
|
| 258 |
return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
|
| 259 |
|
| 260 |
-
@app.post("/v1/chat/completions")
|
| 261 |
async def create_chat_completion(request: ChatCompletionRequest):
|
|
|
|
|
|
|
|
|
|
| 262 |
if request.model not in SUPPORTED_MODELS:
|
| 263 |
raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
|
| 264 |
|
|
@@ -332,7 +349,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
| 332 |
|
| 333 |
@app.get("/")
|
| 334 |
async def root():
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
@app.middleware("http")
|
| 338 |
async def add_performance_headers(request, call_next):
|
|
@@ -340,5 +360,5 @@ async def add_performance_headers(request, call_next):
|
|
| 340 |
response = await call_next(request)
|
| 341 |
process_time = time.time() - start_time
|
| 342 |
response.headers["X-Process-Time"] = str(round(process_time, 3))
|
| 343 |
-
response.headers["X-API-Version"] = "9.2.
|
| 344 |
return response
|
|
|
|
| 3 |
import json
|
| 4 |
import time
|
| 5 |
import asyncio
|
| 6 |
+
from fastapi import FastAPI, HTTPException, Security, Depends, status
|
| 7 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 8 |
from fastapi.responses import StreamingResponse
|
| 9 |
from pydantic import BaseModel, Field
|
| 10 |
from typing import List, Dict, Any, Optional, Union, Literal
|
|
|
|
| 13 |
# Load environment variables
|
| 14 |
load_dotenv()
|
| 15 |
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
|
| 16 |
+
SERVER_API_KEY = os.getenv("SERVER_API_KEY") # <-- New key for server auth
|
| 17 |
+
|
| 18 |
if not REPLICATE_API_TOKEN:
|
| 19 |
raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
|
| 20 |
+
if not SERVER_API_KEY:
|
| 21 |
+
raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
|
| 22 |
|
| 23 |
# FastAPI Init
|
| 24 |
+
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.4 (Server Auth Added)")
|
| 25 |
+
|
| 26 |
+
# --- Authentication ---
|
| 27 |
+
# Single HTTPBearer scheme instance, shared by every protected route.
security = HTTPBearer()

async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> bool:
    """Validate the Bearer token supplied in the Authorization header.

    Args:
        credentials: Parsed ``Authorization`` header, injected by FastAPI's
            ``HTTPBearer`` security scheme.

    Returns:
        True when the presented token matches ``SERVER_API_KEY``.

    Raises:
        HTTPException: 401 when the scheme is not ``Bearer`` or the token
            does not match the configured server key.
    """
    import hmac  # stdlib; local import keeps the file-level import block untouched

    # hmac.compare_digest runs in constant time, so an attacker cannot
    # recover the key byte-by-byte from response-time differences the way
    # a plain `!=` string comparison would allow.
    token_matches = hmac.compare_digest(credentials.credentials, SERVER_API_KEY)
    if credentials.scheme != "Bearer" or not token_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
|
| 40 |
|
| 41 |
# --- Pydantic Models ---
|
| 42 |
class ModelCard(BaseModel):
|
|
|
|
| 220 |
async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
|
| 221 |
current_event = None
|
| 222 |
accumulated_content = ""
|
|
|
|
| 223 |
|
| 224 |
async for line in sse.aiter_lines():
|
| 225 |
if not line: continue
|
|
|
|
| 235 |
except (json.JSONDecodeError, TypeError):
|
| 236 |
content_token = raw_data
|
| 237 |
|
| 238 |
+
# Removed the lstrip() logic to send raw tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
accumulated_content += content_token
|
| 241 |
completion_tokens += 1
|
|
|
|
| 246 |
chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
|
| 247 |
yield f"data: {chunk.json()}\n\n"
|
| 248 |
else:
|
|
|
|
| 249 |
if content_token:
|
| 250 |
chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
|
| 251 |
yield f"data: {chunk.json()}\n\n"
|
|
|
|
| 264 |
yield "data: [DONE]\n\n"
|
| 265 |
|
| 266 |
# --- Endpoints ---
|
| 267 |
+
@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """Return the catalogue of models this proxy can serve.

    Requires a valid Bearer token, enforced by the ``verify_api_key``
    route dependency.
    """
    # Iterating the dict directly yields its keys (the model ids).
    cards = [ModelCard(id=model_id) for model_id in SUPPORTED_MODELS]
    return ModelList(data=cards)
|
| 273 |
|
| 274 |
+
@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
|
| 275 |
async def create_chat_completion(request: ChatCompletionRequest):
|
| 276 |
+
"""
|
| 277 |
+
Protected endpoint to create a chat completion.
|
| 278 |
+
"""
|
| 279 |
if request.model not in SUPPORTED_MODELS:
|
| 280 |
raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
|
| 281 |
|
|
|
|
| 349 |
|
| 350 |
@app.get("/")
async def root():
    """Unauthenticated root endpoint used for health checks."""
    payload = {
        "message": "Replicate to OpenAI Compatibility Layer API",
        "version": "9.2.4",
    }
    return payload
|
| 356 |
|
| 357 |
@app.middleware("http")
|
| 358 |
async def add_performance_headers(request, call_next):
|
|
|
|
| 360 |
response = await call_next(request)
|
| 361 |
process_time = time.time() - start_time
|
| 362 |
response.headers["X-Process-Time"] = str(round(process_time, 3))
|
| 363 |
+
response.headers["X-API-Version"] = "9.2.4"
|
| 364 |
return response
|