Update main.py
main.py (changed)
@@ -22,7 +22,7 @@ if not SERVER_API_KEY:
     raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.3.0 (Streaming Fix)")
 
 # --- Authentication ---
 security = HTTPBearer()
@@ -52,7 +52,7 @@ class ModelList(BaseModel):
 
 class ChatMessage(BaseModel):
     role: Literal["system", "user", "assistant", "tool"]
-    content: Union[str, List[Dict[str, Any]]]
+    content: Union[str, List[Dict[str, Any]], None]  # Allow content to be None for tool calls
     name: Optional[str] = None
     tool_calls: Optional[List[Any]] = None
 
@@ -127,33 +127,37 @@ class ChatCompletionChunk(BaseModel):
 # --- Supported Models ---
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-    "claude-4.5-haiku": "anthropic/claude-
-    "claude-4.5-sonnet": "anthropic/claude-
+    "claude-4.5-haiku": "anthropic/claude-3-haiku-20240307",
+    "claude-4.5-sonnet": "anthropic/claude-3.5-sonnet-20240620",  # Updated to correct model ID
     "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
 }
 
 # --- Core Logic ---
 
 def generate_request_id() -> str:
-    """Generates a unique request ID
-    return f"
+    """Generates a unique request ID."""
+    return f"chatcmpl-{secrets.token_hex(16)}"
 
 def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
     prompt_parts = []
     system_prompt = None
     image_input = None
 
+    # Handle functions/tools if provided
+    tools_prompt_section = ""
     if functions:
-
+        tools_prompt_section += "You have access to the following tools. Use them if required to answer the user's question.\n\n"
         for func in functions:
-
-            if func.description:
-            if func.parameters:
-
+            tools_prompt_section += f"- Function: {func.name}\n"
+            if func.description: tools_prompt_section += f" Description: {func.description}\n"
+            if func.parameters: tools_prompt_section += f" Parameters: {json.dumps(func.parameters)}\n"
+        tools_prompt_section += "\nTo call a function, respond with a JSON object like this: {\"name\": \"function_name\", \"arguments\": {\"arg1\": \"value1\"}}\n"
 
     for msg in messages:
         if msg.role == "system":
             system_prompt = str(msg.content)
+            if tools_prompt_section:
+                system_prompt += "\n\n" + tools_prompt_section
         elif msg.role == "assistant":
             if msg.tool_calls:
                 tool_calls_text = "\nTool calls:\n"
@@ -177,9 +181,12 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
             user_text_content = str(msg.content)
             prompt_parts.append(f"User: {user_text_content}")
 
-
+    if not system_prompt and tools_prompt_section:
+        system_prompt = tools_prompt_section
+
+    prompt_parts.append("Assistant:")
     return {
-        "prompt": "\n
+        "prompt": "\n".join(prompt_parts),
         "system_prompt": system_prompt,
         "image": image_input
     }
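
For orientation, a minimal sketch of what the reworked formatter returns for a plain two-turn exchange (assuming main.py's ChatMessage and format_messages_for_replicate are importable; the messages are illustrative):

from main import ChatMessage, format_messages_for_replicate

messages = [
    ChatMessage(role="system", content="You are a helpful assistant."),
    ChatMessage(role="user", content="What is the capital of France?"),
]
formatted = format_messages_for_replicate(messages)
# With the new trailing cue, formatted["prompt"] is roughly:
#   "User: What is the capital of France?\nAssistant:"
# formatted["system_prompt"] == "You are a helpful assistant."
# formatted["image"] is None (no image content was supplied)
print(formatted["prompt"])
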
@@ -198,7 +205,7 @@ def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
         pass
     return None
 
-async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str):
+async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str, model_name: str):
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
@@ -225,7 +232,6 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                 current_event = None
-                accumulated_content = ""
 
                 async for line in sse.aiter_lines():
                     if not line: continue
@@ -237,34 +243,42 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
                         if not raw_data: continue
 
                         try:
+                            # Replicate sends JSON-encoded strings. This correctly handles escaped chars like \n
                             content_token = json.loads(raw_data)
+                            if not isinstance(content_token, str):
+                                content_token = str(content_token)  # Ensure it's a string
                         except (json.JSONDecodeError, TypeError):
-                            content_token = raw_data
-
-                            # ### THIS IS THE FIX ###
-                            # There is NO lstrip() or strip() here.
-                            # This sends the raw, unmodified token from Replicate.
-                            # If the log shows "HowcanI", it's because the model
-                            # sent "How", "can", "I" as separate tokens.
-
-                        accumulated_content += content_token
+                            content_token = raw_data  # Fallback for non-JSON data
+
                         completion_tokens += 1
 
-
-
-
-
-
-
-
-
-                        yield f"data: {chunk.json()}\n\n"
+                        if content_token:
+                            chunk = ChatCompletionChunk(
+                                id=request_id,
+                                created=int(time.time()),
+                                model=model_name,
+                                choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token))]
+                            )
+                            yield f"data: {chunk.model_dump_json()}\n\n"
 
+                    # --- THIS IS THE CRITICAL FIX for incomplete responses ---
+                    elif line.startswith("data:") and current_event == "error":
+                        raw_data = line[5:].strip()
+                        error_details = raw_data
+                        try:
+                            error_json = json.loads(raw_data)
+                            error_details = error_json.get("detail") or str(error_json)
+                        except json.JSONDecodeError: pass
+
+                        error_chunk = {"error": {"message": f"Replicate stream error: {error_details}", "type": "replicate_error"}}
+                        yield f"data: {json.dumps(error_chunk)}\n\n"
+                        break  # Stop streaming on error
+
                     elif current_event == "done":
                         end_time = time.time()
                         usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, inference_time=round(end_time - start_time, 3))
-                        usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=
-                        yield f"data: {usage_chunk.
+                        usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=model_name, choices=[ChoiceDelta(index=0, delta=DeltaMessage(), finish_reason="stop")], usage=usage)
+                        yield f"data: {usage_chunk.model_dump_json()}\n\n"
                         break
 
             except httpx.ReadTimeout:
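
To sanity-check the streaming path end to end, a small client-side sketch (illustrative: it assumes the server runs locally on port 8000 and that SERVER_API_KEY is exported in the environment; the payload follows the OpenAI-style schema this endpoint accepts):

import json
import os

import httpx

payload = {
    "model": "llama3-8b-instruct",
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "stream": True,
}
headers = {"Authorization": f"Bearer {os.environ['SERVER_API_KEY']}"}

with httpx.Client(timeout=None) as client:
    with client.stream("POST", "http://localhost:8000/v1/chat/completions", json=payload, headers=headers) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip keep-alives and blank lines
            chunk = json.loads(line[len("data: "):])
            if "error" in chunk:
                raise RuntimeError(chunk["error"]["message"])
            delta = chunk.get("choices", [{}])[0].get("delta", {})
            if delta.get("content"):
                print(delta["content"], end="", flush=True)
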
@@ -276,16 +290,10 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
 # --- Endpoints ---
 @app.get("/v1/models", dependencies=[Depends(verify_api_key)])
 async def list_models():
-    """
-    Protected endpoint to list available models.
-    """
     return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
 async def create_chat_completion(request: ChatCompletionRequest):
-    """
-    Protected endpoint to create a chat completion.
-    """
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
@@ -294,13 +302,17 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     replicate_input = {
         "prompt": formatted["prompt"],
-        "temperature": request.temperature
-        "top_p": request.top_p
+        "temperature": request.temperature if request.temperature is not None else 0.7,
+        "top_p": request.top_p if request.top_p is not None else 1.0,
     }
 
     if request.max_tokens is not None:
         replicate_input["max_new_tokens"] = request.max_tokens
 
+    # --- THIS IS THE SECOND FIX for incomplete responses ---
+    if request.stop:
+        replicate_input["stop_sequences"] = request.stop if isinstance(request.stop, list) else [request.stop]
+
     if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
     if formatted["image"]: replicate_input["image"] = formatted["image"]
 
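
As a quick illustration of the new stop handling in isolation (values are made up), both the OpenAI-style string form and list form end up as a Replicate stop_sequences list:

# Illustrative only: mirrors the normalization added above.
for stop in ("\nUser:", ["\nUser:", "</s>"], None):
    replicate_input = {}
    if stop:
        replicate_input["stop_sequences"] = stop if isinstance(stop, list) else [stop]
    print(replicate_input)
# -> {'stop_sequences': ['\nUser:']}
# -> {'stop_sequences': ['\nUser:', '</s>']}
# -> {}
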
@@ -308,7 +320,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         return StreamingResponse(
-            stream_replicate_response(replicate_model_id, replicate_input, request_id),
+            stream_replicate_response(replicate_model_id, replicate_input, request_id, request.model),
             media_type="text/event-stream"
         )
 
@@ -323,21 +335,13 @@ async def create_chat_completion(request: ChatCompletionRequest):
     resp.raise_for_status()
     pred = resp.json()
 
-    # Handle
+    # Handle errors in non-streaming mode
+    if pred.get("status") == "failed":
+        raise HTTPException(status_code=500, detail=f"Replicate prediction failed: {pred.get('error')}")
+
     raw_output = pred.get("output")
+    output = "".join(raw_output) if isinstance(raw_output, list) else (raw_output or "")
 
-    if isinstance(raw_output, list):
-        output = "".join(raw_output)  # Expected case: list of strings
-    elif isinstance(raw_output, str):
-        output = raw_output  # Handle if it's just a single string
-    else:
-        output = ""
-
-    # ### THIS IS THE FIX ###
-    # Removed output.strip() to return the raw response.
-    # This fixes the bug where a single space (" ") response
-    # would become "" and show content: "" in the JSON.
-
     end_time = time.time()
     prompt_tokens = len(replicate_input.get("prompt", "")) // 4
     completion_tokens = len(output) // 4
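
The replacement one-liner above collapses the old list/str/None branches; a tiny standalone check of its behaviour (inputs are made up):

# Illustrative only: Replicate's "output" may be a list of strings, a single string, or missing.
for raw_output in (["Hel", "lo", "!"], "Hello!", None):
    output = "".join(raw_output) if isinstance(raw_output, list) else (raw_output or "")
    print(repr(output))
# -> 'Hello!'
# -> 'Hello!'
# -> ''
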
@@ -348,9 +352,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     function_call = parse_function_call(output)
     if function_call:
-        tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))]
+        tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=json.dumps(function_call["arguments"])))]
         finish_reason = "tool_calls"
-        message_content = None
+        message_content = None
 
     return ChatCompletion(
         id=request_id,
@@ -375,10 +379,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
 @app.get("/")
 async def root():
-    """
-    Root endpoint for health checks. Does not require authentication.
-    """
-    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.8"}
+    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.3.0"}
 
 @app.middleware("http")
 async def add_performance_headers(request, call_next):
@@ -386,5 +387,5 @@ async def add_performance_headers(request, call_next):
     response = await call_next(request)
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = str(round(process_time, 3))
-    response.headers["X-API-Version"] = "9.
+    response.headers["X-API-Version"] = "9.3.0"
     return response
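
A minimal non-streaming smoke test against a deployed instance might look like this (base URL and environment variable are assumptions; the response shape follows the OpenAI-style ChatCompletion model defined in this file):

import os

import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['SERVER_API_KEY']}"},
    json={
        "model": "llama3-8b-instruct",
        "messages": [{"role": "user", "content": "Give me one fun fact."}],
        "max_tokens": 128,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
print(resp.headers.get("X-API-Version"), resp.headers.get("X-Process-Time"))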