Update main.py
Browse files
main.py
CHANGED
|
@@ -5,7 +5,7 @@ import json
|
|
| 5 |
import time
|
| 6 |
import asyncio
|
| 7 |
from fastapi import FastAPI, HTTPException
|
| 8 |
-
from fastapi.responses import StreamingResponse
|
| 9 |
from pydantic import BaseModel, Field
|
| 10 |
from typing import List, Dict, Any, Optional, Union, Literal
|
| 11 |
from dotenv import load_dotenv
|
|
@@ -17,7 +17,7 @@ if not REPLICATE_API_TOKEN:
|
|
| 17 |
raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
|
| 18 |
|
| 19 |
# FastAPI Init
|
| 20 |
-
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.
|
| 21 |
|
| 22 |
# --- Pydantic Models ---
|
| 23 |
class ModelCard(BaseModel):
|
|
@@ -106,6 +106,7 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
|
|
| 106 |
user_text_content = str(msg.content)
|
| 107 |
prompt_parts.append(f"User: {user_text_content}")
|
| 108 |
|
|
|
|
| 109 |
prompt_parts.append("Assistant:")
|
| 110 |
return {
|
| 111 |
"prompt": "\n\n".join(prompt_parts),
|
|
@@ -161,6 +162,7 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
|
|
| 161 |
async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
|
| 162 |
current_event = None
|
| 163 |
accumulated_content = ""
|
|
|
|
| 164 |
|
| 165 |
async for line in sse.aiter_lines():
|
| 166 |
if not line: continue
|
|
@@ -177,6 +179,11 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
|
|
| 177 |
except (json.JSONDecodeError, TypeError):
|
| 178 |
content_token = raw_data
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
accumulated_content += content_token
|
| 181 |
completion_tokens += 1
|
| 182 |
|
|
@@ -291,6 +298,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
| 291 |
pred = resp.json()
|
| 292 |
output = "".join(pred.get("output", []))
|
| 293 |
|
|
|
|
|
|
|
|
|
|
| 294 |
# Calculate timing and tokens
|
| 295 |
end_time = time.time()
|
| 296 |
inference_time = end_time - start_time
|
|
@@ -337,7 +347,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
| 337 |
|
| 338 |
@app.get("/")
|
| 339 |
async def root():
|
| 340 |
-
return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.
|
| 341 |
|
| 342 |
# Performance optimization middleware
|
| 343 |
@app.middleware("http")
|
|
@@ -346,5 +356,5 @@ async def add_performance_headers(request, call_next):
|
|
| 346 |
response = await call_next(request)
|
| 347 |
process_time = time.time() - start_time
|
| 348 |
response.headers["X-Process-Time"] = str(round(process_time, 3))
|
| 349 |
-
response.headers["X-API-Version"] = "9.2.
|
| 350 |
-
return response
|
|
|
|
| 5 |
import time
|
| 6 |
import asyncio
|
| 7 |
from fastapi import FastAPI, HTTPException
|
| 8 |
+
from fastapi.responses import StreamingResponse
|
| 9 |
from pydantic import BaseModel, Field
|
| 10 |
from typing import List, Dict, Any, Optional, Union, Literal
|
| 11 |
from dotenv import load_dotenv
|
|
|
|
| 17 |
raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
|
| 18 |
|
| 19 |
# FastAPI Init
|
| 20 |
+
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.1 (Spacing Fixed)")
|
| 21 |
|
| 22 |
# --- Pydantic Models ---
|
| 23 |
class ModelCard(BaseModel):
|
|
|
|
| 106 |
user_text_content = str(msg.content)
|
| 107 |
prompt_parts.append(f"User: {user_text_content}")
|
| 108 |
|
| 109 |
+
# Fix: Don't add trailing space, let model decide spacing
|
| 110 |
prompt_parts.append("Assistant:")
|
| 111 |
return {
|
| 112 |
"prompt": "\n\n".join(prompt_parts),
|
|
|
|
| 162 |
async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
|
| 163 |
current_event = None
|
| 164 |
accumulated_content = ""
|
| 165 |
+
first_token = True
|
| 166 |
|
| 167 |
async for line in sse.aiter_lines():
|
| 168 |
if not line: continue
|
|
|
|
| 179 |
except (json.JSONDecodeError, TypeError):
|
| 180 |
content_token = raw_data
|
| 181 |
|
| 182 |
+
# Fix: Handle spacing properly - don't prepend space to first token
|
| 183 |
+
if first_token:
|
| 184 |
+
content_token = content_token.lstrip()
|
| 185 |
+
first_token = False
|
| 186 |
+
|
| 187 |
accumulated_content += content_token
|
| 188 |
completion_tokens += 1
|
| 189 |
|
|
|
|
| 298 |
pred = resp.json()
|
| 299 |
output = "".join(pred.get("output", []))
|
| 300 |
|
| 301 |
+
# Fix: Clean up leading/trailing whitespace
|
| 302 |
+
output = output.strip()
|
| 303 |
+
|
| 304 |
# Calculate timing and tokens
|
| 305 |
end_time = time.time()
|
| 306 |
inference_time = end_time - start_time
|
|
|
|
| 347 |
|
| 348 |
@app.get("/")
|
| 349 |
async def root():
|
| 350 |
+
return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.1"}
|
| 351 |
|
| 352 |
# Performance optimization middleware
|
| 353 |
@app.middleware("http")
|
|
|
|
| 356 |
response = await call_next(request)
|
| 357 |
process_time = time.time() - start_time
|
| 358 |
response.headers["X-Process-Time"] = str(round(process_time, 3))
|
| 359 |
+
response.headers["X-API-Version"] = "9.2.1"
|
| 360 |
+
return response
|