Update main.py
Browse files
main.py
CHANGED
|
@@ -22,7 +22,7 @@ if not SERVER_API_KEY:
|
|
| 22 |
raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
|
| 23 |
|
| 24 |
# FastAPI Init
|
| 25 |
-
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.
|
| 26 |
|
| 27 |
# --- Authentication ---
|
| 28 |
security = HTTPBearer()
|
|
@@ -241,7 +241,8 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
|
|
| 241 |
except (json.JSONDecodeError, TypeError):
|
| 242 |
content_token = raw_data
|
| 243 |
|
| 244 |
-
#
|
|
|
|
| 245 |
|
| 246 |
accumulated_content += content_token
|
| 247 |
completion_tokens += 1
|
|
@@ -294,7 +295,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
| 294 |
"top_p": request.top_p or 1.0
|
| 295 |
}
|
| 296 |
|
| 297 |
-
# Only add max_new_tokens if the user *actually* provided it.
|
| 298 |
if request.max_tokens is not None:
|
| 299 |
replicate_input["max_new_tokens"] = request.max_tokens
|
| 300 |
|
|
@@ -320,8 +320,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
| 320 |
resp.raise_for_status()
|
| 321 |
pred = resp.json()
|
| 322 |
|
| 323 |
-
#
|
| 324 |
-
# Robustly handle the 'output' field which could be a list, string, or null
|
| 325 |
raw_output = pred.get("output")
|
| 326 |
|
| 327 |
if isinstance(raw_output, list):
|
|
@@ -329,10 +328,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
| 329 |
elif isinstance(raw_output, str):
|
| 330 |
output = raw_output # Handle if it's just a single string
|
| 331 |
else:
|
| 332 |
-
# Handle None, null, int, bool, or other unexpected types
|
| 333 |
output = ""
|
| 334 |
-
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
end_time = time.time()
|
| 338 |
prompt_tokens = len(replicate_input.get("prompt", "")) // 4
|
|
@@ -367,7 +368,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
| 367 |
except httpx.HTTPStatusError as e:
|
| 368 |
raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
|
| 369 |
except Exception as e:
|
| 370 |
-
# Catch the join error and any others
|
| 371 |
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
| 372 |
|
| 373 |
@app.get("/")
|
|
@@ -375,7 +375,7 @@ async def root():
|
|
| 375 |
"""
|
| 376 |
Root endpoint for health checks. Does not require authentication.
|
| 377 |
"""
|
| 378 |
-
return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.
|
| 379 |
|
| 380 |
@app.middleware("http")
|
| 381 |
async def add_performance_headers(request, call_next):
|
|
@@ -383,5 +383,5 @@ async def add_performance_headers(request, call_next):
|
|
| 383 |
response = await call_next(request)
|
| 384 |
process_time = time.time() - start_time
|
| 385 |
response.headers["X-Process-Time"] = str(round(process_time, 3))
|
| 386 |
-
response.headers["X-API-Version"] = "9.2.
|
| 387 |
return response
|
|
|
|
| 22 |
raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
|
| 23 |
|
| 24 |
# FastAPI Init
|
| 25 |
+
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.8 (Raw Output Fix)")
|
| 26 |
|
| 27 |
# --- Authentication ---
|
| 28 |
security = HTTPBearer()
|
|
|
|
| 241 |
except (json.JSONDecodeError, TypeError):
|
| 242 |
content_token = raw_data
|
| 243 |
|
| 244 |
+
# There is NO lstrip() or strip() here.
|
| 245 |
+
# This sends the raw, unmodified token.
|
| 246 |
|
| 247 |
accumulated_content += content_token
|
| 248 |
completion_tokens += 1
|
|
|
|
| 295 |
"top_p": request.top_p or 1.0
|
| 296 |
}
|
| 297 |
|
|
|
|
| 298 |
if request.max_tokens is not None:
|
| 299 |
replicate_input["max_new_tokens"] = request.max_tokens
|
| 300 |
|
|
|
|
| 320 |
resp.raise_for_status()
|
| 321 |
pred = resp.json()
|
| 322 |
|
| 323 |
+
# Handle the 'output' field which could be a list, string, or null
|
|
|
|
| 324 |
raw_output = pred.get("output")
|
| 325 |
|
| 326 |
if isinstance(raw_output, list):
|
|
|
|
| 328 |
elif isinstance(raw_output, str):
|
| 329 |
output = raw_output # Handle if it's just a single string
|
| 330 |
else:
|
|
|
|
| 331 |
output = ""
|
| 332 |
+
|
| 333 |
+
# ### MAJOR FIX HERE (Non-Streaming) ###
|
| 334 |
+
# Removed output.strip() to return the raw response,
|
| 335 |
+
# even if it's just a space.
|
| 336 |
+
# output = output.strip() # <-- REMOVED
|
| 337 |
|
| 338 |
end_time = time.time()
|
| 339 |
prompt_tokens = len(replicate_input.get("prompt", "")) // 4
|
|
|
|
| 368 |
except httpx.HTTPStatusError as e:
|
| 369 |
raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
|
| 370 |
except Exception as e:
|
|
|
|
| 371 |
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
| 372 |
|
| 373 |
@app.get("/")
|
|
|
|
| 375 |
"""
|
| 376 |
Root endpoint for health checks. Does not require authentication.
|
| 377 |
"""
|
| 378 |
+
return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.8"}
|
| 379 |
|
| 380 |
@app.middleware("http")
|
| 381 |
async def add_performance_headers(request, call_next):
|
|
|
|
| 383 |
response = await call_next(request)
|
| 384 |
process_time = time.time() - start_time
|
| 385 |
response.headers["X-Process-Time"] = str(round(process_time, 3))
|
| 386 |
+
response.headers["X-API-Version"] = "9.2.8"
|
| 387 |
return response
|