rkihacker committed on
Commit
9f14d65
·
verified ·
1 Parent(s): ff93199

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +36 -16
main.py CHANGED
@@ -3,7 +3,8 @@ import httpx
3
  import json
4
  import time
5
  import asyncio
6
- from fastapi import FastAPI, HTTPException
 
7
  from fastapi.responses import StreamingResponse
8
  from pydantic import BaseModel, Field
9
  from typing import List, Dict, Any, Optional, Union, Literal
@@ -12,11 +13,30 @@ from dotenv import load_dotenv
12
  # Load environment variables
13
  load_dotenv()
14
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
 
 
15
  if not REPLICATE_API_TOKEN:
16
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 
17
 
18
  # FastAPI Init
19
- app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.3 (Raw Stream)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # --- Pydantic Models ---
22
  class ModelCard(BaseModel):
@@ -200,7 +220,6 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
200
  async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
201
  current_event = None
202
  accumulated_content = ""
203
- # first_token = True <- REMOVED THIS
204
 
205
  async for line in sse.aiter_lines():
206
  if not line: continue
@@ -216,14 +235,7 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
216
  except (json.JSONDecodeError, TypeError):
217
  content_token = raw_data
218
 
219
- # ### MAJOR FIX HERE ###
220
- # The lstrip() logic has been COMPLETELY REMOVED
221
- # to send the raw, unmodified token from Replicate.
222
- #
223
- # if first_token:
224
- # content_token = content_token.lstrip()
225
- # if content_token:
226
- # first_token = False
227
 
228
  accumulated_content += content_token
229
  completion_tokens += 1
@@ -234,7 +246,6 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
234
  chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
235
  yield f"data: {chunk.json()}\n\n"
236
  else:
237
- # Only yield a chunk if there is content to send.
238
  if content_token:
239
  chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
240
  yield f"data: {chunk.json()}\n\n"
@@ -253,12 +264,18 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
253
  yield "data: [DONE]\n\n"
254
 
255
  # --- Endpoints ---
256
- @app.get("/v1/models")
257
  async def list_models():
 
 
 
258
  return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
259
 
260
- @app.post("/v1/chat/completions")
261
  async def create_chat_completion(request: ChatCompletionRequest):
 
 
 
262
  if request.model not in SUPPORTED_MODELS:
263
  raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
264
 
@@ -332,7 +349,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
332
 
333
  @app.get("/")
334
  async def root():
335
- return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.3"}
 
 
 
336
 
337
  @app.middleware("http")
338
  async def add_performance_headers(request, call_next):
@@ -340,5 +360,5 @@ async def add_performance_headers(request, call_next):
340
  response = await call_next(request)
341
  process_time = time.time() - start_time
342
  response.headers["X-Process-Time"] = str(round(process_time, 3))
343
- response.headers["X-API-Version"] = "9.2.3"
344
  return response
 
3
  import json
4
  import time
5
  import asyncio
6
+ from fastapi import FastAPI, HTTPException, Security, Depends, status
7
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
  from fastapi.responses import StreamingResponse
9
  from pydantic import BaseModel, Field
10
  from typing import List, Dict, Any, Optional, Union, Literal
 
13
  # Load environment variables
14
  load_dotenv()
15
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
16
+ SERVER_API_KEY = os.getenv("SERVER_API_KEY") # <-- New key for server auth
17
+
18
  if not REPLICATE_API_TOKEN:
19
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
20
+ if not SERVER_API_KEY:
21
+ raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
22
 
23
  # FastAPI Init
24
+ app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.4 (Server Auth Added)")
25
+
26
+ # --- Authentication ---
27
+ security = HTTPBearer()
28
+
29
+ async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
30
+ """
31
+ Verify the API key provided in the Authorization header.
32
+ """
33
+ if credentials.scheme != "Bearer" or credentials.credentials != SERVER_API_KEY:
34
+ raise HTTPException(
35
+ status_code=status.HTTP_401_UNAUTHORIZED,
36
+ detail="Invalid or missing API key",
37
+ headers={"WWW-Authenticate": "Bearer"},
38
+ )
39
+ return True
40
 
41
  # --- Pydantic Models ---
42
  class ModelCard(BaseModel):
 
220
  async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
221
  current_event = None
222
  accumulated_content = ""
 
223
 
224
  async for line in sse.aiter_lines():
225
  if not line: continue
 
235
  except (json.JSONDecodeError, TypeError):
236
  content_token = raw_data
237
 
238
+ # Removed the lstrip() logic to send raw tokens
 
 
 
 
 
 
 
239
 
240
  accumulated_content += content_token
241
  completion_tokens += 1
 
246
  chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
247
  yield f"data: {chunk.json()}\n\n"
248
  else:
 
249
  if content_token:
250
  chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
251
  yield f"data: {chunk.json()}\n\n"
 
264
  yield "data: [DONE]\n\n"
265
 
266
  # --- Endpoints ---
267
+ @app.get("/v1/models", dependencies=[Depends(verify_api_key)])
268
  async def list_models():
269
+ """
270
+ Protected endpoint to list available models.
271
+ """
272
  return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
273
 
274
+ @app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
275
  async def create_chat_completion(request: ChatCompletionRequest):
276
+ """
277
+ Protected endpoint to create a chat completion.
278
+ """
279
  if request.model not in SUPPORTED_MODELS:
280
  raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
281
 
 
349
 
350
  @app.get("/")
351
  async def root():
352
+ """
353
+ Root endpoint for health checks. Does not require authentication.
354
+ """
355
+ return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.4"}
356
 
357
  @app.middleware("http")
358
  async def add_performance_headers(request, call_next):
 
360
  response = await call_next(request)
361
  process_time = time.time() - start_time
362
  response.headers["X-Process-Time"] = str(round(process_time, 3))
363
+ response.headers["X-API-Version"] = "9.2.4"
364
  return response