Update main.py
main.py
CHANGED
@@ -20,7 +20,7 @@ if not REPLICATE_API_TOKEN:
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="2.
+    version="2.2.0 (Stable Streaming)",
 )
 
 # --- Pydantic Models ---
@@ -45,14 +45,9 @@ SUPPORTED_MODELS = {
 # --- Helper Functions ---
 
 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
-    """
-    Prepares the input payload for Replicate, handling model-specific formats.
-    """
+    """Prepares the input payload for Replicate, handling model-specific formats."""
     payload = {}
 
-    # *** THIS IS THE CRITICAL FIX ***
-    # Claude models on Replicate require a single 'prompt' string.
-    # We must convert the 'messages' array into a formatted string.
     if "claude" in request.model:
         prompt_parts = []
         system_prompt = None
@@ -72,6 +67,9 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, A
                 prompt_parts.append(f"User: {msg.content}")
             elif msg.role == "assistant":
                 prompt_parts.append(f"Assistant: {msg.content}")
+
+        # Add final turn for the assistant to respond
+        prompt_parts.append("Assistant:")
 
         payload["prompt"] = "\n".join(prompt_parts)
         if system_prompt:
@@ -79,11 +77,9 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, A
         if image_url:
             payload["image"] = image_url
 
-    #
-    else:
+    else:  # Llama-3 and other standard chat models
         payload["messages"] = [msg.dict() for msg in request.messages]
 
-    # Add common parameters
     if request.max_tokens is not None:
         payload["max_new_tokens"] = request.max_tokens
     if request.temperature is not None:
@@ -126,7 +122,10 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
                 current_event = line[len("event:"):].strip()
             elif line.startswith("data:"):
                 data = line[len("data:"):].strip()
-
+
+                # *** THIS IS THE CRITICAL FIX ***
+                # Only process non-empty data for 'output' events
+                if data and current_event == "output":
                     chunk = {
                         "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
                         "choices": [{"index": 0, "delta": {"content": json.loads(data)}, "finish_reason": None}]