Update main.py
main.py
CHANGED
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="7.
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="7.1.0 (Streaming Space Fix)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -39,51 +39,41 @@ SUPPORTED_MODELS = {
 # --- Core Logic ---
 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     """
-    Formats the input for
-
-
+    Formats the input for Replicate's API, flattening the message history into a
+    single 'prompt' string and handling images separately. This is the required
+    format for all their current chat/vision models.
     """
     payload = {}
-
     prompt_parts = []
     system_prompt = None
     image_input = None
 
     for msg in request.messages:
         if msg.role == "system":
-            # Extract system prompt; it will be a separate parameter.
             system_prompt = str(msg.content)
         elif msg.role == "assistant":
             prompt_parts.append(f"Assistant: {msg.content}")
         elif msg.role == "user":
             user_text_content = ""
             if isinstance(msg.content, list):
-                # Handle multimodal (vision) input from OpenAI format
                 for item in msg.content:
                     if item.get("type") == "text":
                         user_text_content += item.get("text", "")
                     elif item.get("type") == "image_url":
                         image_url_data = item.get("image_url", {})
-                        # The 'image' parameter is used by Claude, Llava, etc., on Replicate
                        image_input = image_url_data.get("url")
             else:
                 user_text_content = str(msg.content)
-
             prompt_parts.append(f"User: {user_text_content}")
 
-    # The final "Assistant:" turn prompts the model for a response.
     prompt_parts.append("Assistant:")
-
-    # All models on Replicate's API expect a single 'prompt' string.
     payload["prompt"] = "\n\n".join(prompt_parts)
 
     if system_prompt:
         payload["system_prompt"] = system_prompt
-
     if image_input:
         payload["image"] = image_input
 
-    # Map common OpenAI parameters to Replicate equivalents
     if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
     if request.temperature: payload["temperature"] = request.temperature
     if request.top_p: payload["top_p"] = request.top_p
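For context on what this function emits, here is a minimal sketch of the flattening logic above, run on an illustrative message list (the message contents are made up, not from this commit):

    # Mirrors prepare_replicate_input's flattening; messages are illustrative.
    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi."},
        {"role": "user", "content": "What is 2+2?"},
    ]

    prompt_parts = []
    system_prompt = None
    for msg in messages:
        if msg["role"] == "system":
            system_prompt = str(msg["content"])  # sent as payload["system_prompt"]
        elif msg["role"] == "assistant":
            prompt_parts.append(f"Assistant: {msg['content']}")
        elif msg["role"] == "user":
            prompt_parts.append(f"User: {msg['content']}")
    prompt_parts.append("Assistant:")  # the final turn cues the model to respond

    print("\n\n".join(prompt_parts))
    # User: Hello
    #
    # Assistant: Hi.
    #
    # User: What is 2+2?
    #
    # Assistant: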
@@ -91,7 +81,7 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     return payload
 
 async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
-    """Handles the full streaming lifecycle
+    """Handles the full streaming lifecycle with robust token parsing."""
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
@@ -102,20 +92,17 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
             prediction = response.json()
             stream_url = prediction.get("urls", {}).get("stream")
             prediction_id = prediction.get("id", "stream-unknown")
-
             if not stream_url:
-
-
-
+                yield json.dumps({"error": {"message": "Model did not return a stream URL."}})
+                return
         except httpx.HTTPStatusError as e:
-
-
-
-
-
-
-
-            return
+            error_details = e.response.text
+            try:
+                error_json = e.response.json()
+                error_details = error_json.get("detail", error_details)
+            except json.JSONDecodeError: pass
+            yield json.dumps({"error": {"message": f"Upstream Error: {error_details}", "type": "replicate_error"}})
+            return
 
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
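For reference, the rewritten except branch surfaces upstream failures to the client as a single JSON error event instead of returning silently; a small standalone sketch of the envelope it yields (the upstream "detail" text is invented):

    import json

    # Hypothetical Replicate error body; the real "detail" text will vary.
    upstream = {"detail": "Input validation failed: prompt is required."}

    error_details = upstream.get("detail", "")
    print(json.dumps({"error": {"message": f"Upstream Error: {error_details}",
                                "type": "replicate_error"}}))
    # {"error": {"message": "Upstream Error: Input validation failed: prompt is required.", "type": "replicate_error"}}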
@@ -126,12 +113,29 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                     elif line.startswith("data:"):
                         data = line[len("data:"):].strip()
                         if current_event == "output":
-
+                            # --- START OF STREAMING FIX ---
+                            # Replicate streams tokens that can be plain text or JSON-encoded strings.
+                            # We need to robustly parse them to preserve spaces correctly.
+                            content_token = ""
+                            try:
+                                # Attempt to parse data as JSON. This handles tokens like "\" Hello\""
+                                decoded_data = json.loads(data)
+                                if isinstance(decoded_data, str):
+                                    content_token = decoded_data
+                                else:
+                                    # It's some other JSON type, convert to string
+                                    content_token = str(decoded_data)
+                            except json.JSONDecodeError:
+                                # It's not valid JSON, so it's a plain text token.
+                                content_token = data
+
+                            if content_token:
                                 chunk = {
                                     "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
-                                    "choices": [{"index": 0, "delta": {"content":
+                                    "choices": [{"index": 0, "delta": {"content": content_token}, "finish_reason": None}]
                                 }
                                 yield json.dumps(chunk)
+                            # --- END OF STREAMING FIX ---
                         elif current_event == "done":
                             break
         except httpx.ReadTimeout:
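To see why the JSON-decode pass in this hunk fixes the spacing, consider a standalone version of the same parse (token values are illustrative). The .strip() on the raw data line eats a leading space on a bare token, but JSON quoting protects it, and the decode recovers it:

    import json

    def parse_token(data: str) -> str:
        # Same strategy as the fix: JSON-decode when possible, else plain text.
        try:
            decoded = json.loads(data)
            return decoded if isinstance(decoded, str) else str(decoded)
        except json.JSONDecodeError:
            return data

    print(repr(parse_token('" world"')))  # ' world'  (leading space preserved)
    print(repr(parse_token('world')))     # 'world'   (plain token passes through)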
@@ -148,12 +152,10 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
 # --- Endpoints ---
 @app.get("/v1/models")
 async def list_models():
-    """Lists the currently supported models."""
     return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: OpenAIChatCompletionRequest):
-    """Handles chat completion requests, streaming or non-streaming."""
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
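Finally, a hedged end-to-end check of the two endpoints, assuming the app is served locally on port 8000, that the generator above is wrapped as a proper SSE response elsewhere in main.py, and that the model id below exists in SUPPORTED_MODELS (which is defined outside the hunks shown):

    # Usage sketch via the standard openai client; base_url, port, and model id
    # are assumptions, not taken from this commit.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

    print([m.id for m in client.models.list().data])  # GET /v1/models

    stream = client.chat.completions.create(
        model="meta/meta-llama-3-8b-instruct",  # hypothetical SUPPORTED_MODELS entry
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,  # exercises stream_replicate_sse
    )
    for chunk in stream:
        if chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")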