rkihacker committed
Commit e014ad9 · verified · 1 Parent(s): 415ec30

Update main.py

Files changed (1): main.py (+105, -74)
main.py CHANGED
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="4.0.0 (Docs Compliant)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="4.1.0 (Context Fixed)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -33,43 +33,43 @@ class OpenAIChatCompletionRequest(BaseModel):
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
     "claude-4.5-haiku": "anthropic/claude-4.5-haiku"
+    # You can add more models here
 }
 
 # --- Core Logic ---
-def prepare_replicate_input(request: OpenAIChatCompletionRequest, replicate_model_id: str) -> Dict[str, Any]:
-    """Formats the input specifically for the requested Replicate model."""
+def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
+    """
+    Formats the input for Replicate API, preserving the conversational context.
+    """
     payload = {}
 
-    # Claude on Replicate strictly requires a 'prompt' string, not 'messages' array.
-    if "anthropic/claude" in replicate_model_id:
-        prompt_parts = []
-        system_prompt = None
-        for msg in request.messages:
-            if msg.role == "system":
-                # Extract system prompt if present
-                system_prompt = str(msg.content)
-            elif msg.role == "user":
-                # Handle both simple string content and list content (for potential future vision support)
-                content = msg.content
-                if isinstance(content, list):
-                    text_parts = [item.get("text", "") for item in content if item.get("type") == "text"]
-                    content = " ".join(text_parts)
-                prompt_parts.append(f"User: {content}")
-            elif msg.role == "assistant":
-                prompt_parts.append(f"Assistant: {msg.content}")
-
-        # Standard Claude prompting convention
-        prompt_parts.append("Assistant:")
-        payload["prompt"] = "\n\n".join(prompt_parts)
-        if system_prompt:
-            payload["system_prompt"] = system_prompt
-
-    # Llama 3 and others often support the 'messages' array natively.
-    else:
-        # Convert Pydantic models to pure dicts
-        payload["prompt"] = [msg.dict() for msg in request.messages]
+    # --- CONTEXT FIX START ---
+    # Modern chat models on Replicate (like Llama 3 and Claude 4.5) expect
+    # the 'messages' array directly, just like OpenAI.
+    # We no longer need to flatten the conversation into a single prompt string.
+
+    # Extract system prompt if it exists, as some models take it as a separate parameter.
+    messages_for_payload = []
+    system_prompt = None
+    for msg in request.messages:
+        if msg.role == "system":
+            # Claude and some other models prefer a dedicated system_prompt field.
+            system_prompt = str(msg.content)
+        else:
+            # Handle user/assistant roles. Convert Pydantic model to a standard dict.
+            messages_for_payload.append(msg.dict())
+
+    # The main input for conversation is the 'messages' array.
+    payload["messages"] = messages_for_payload
+
+    # Add system_prompt to the payload if it was found.
+    if system_prompt:
+        payload["system_prompt"] = system_prompt
+
+    # --- CONTEXT FIX END ---
 
     # Map common OpenAI parameters to Replicate equivalents
+    # Note: Replicate's parameter for max tokens is often 'max_new_tokens'
     if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
     if request.temperature: payload["temperature"] = request.temperature
     if request.top_p: payload["top_p"] = request.top_p
@@ -78,85 +78,116 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest, replicate_mode
 
 async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
     """Handles the full streaming lifecycle using standard Replicate endpoints."""
-    # 1. Start Prediction specifically at the named model endpoint
+    # 1. Start Prediction
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
     async with httpx.AsyncClient(timeout=60.0) as client:
         try:
-            # Explicitly request stream=True in the body, though often implicit
+            # Request a streaming prediction
            response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
            response.raise_for_status()
            prediction = response.json()
            stream_url = prediction.get("urls", {}).get("stream")
-            prediction_id = prediction.get("id")
+            prediction_id = prediction.get("id", "stream-unknown")
 
            if not stream_url:
                yield json.dumps({"error": {"message": "Model did not return a stream URL."}})
                return
 
        except httpx.HTTPStatusError as e:
-            yield json.dumps({"error": {"message": e.response.text, "type": "upstream_error"}})
+            error_details = e.response.text
+            try:
+                # Try to parse the error for a cleaner message
+                error_json = e.response.json()
+                error_details = error_json.get("detail", error_details)
+            except json.JSONDecodeError:
+                pass  # Use raw text if not JSON
+            yield json.dumps({"error": {"message": f"Upstream Error: {error_details}", "type": "replicate_error"}})
            return
 
-        # 2. Connect to the provided Stream URL
-        async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
-            current_event = None
-            async for line in sse.aiter_lines():
-                if line.startswith("event:"):
-                    current_event = line[len("event:"):].strip()
-                elif line.startswith("data:"):
-                    data = line[len("data:"):].strip()
-
-                    if current_event == "output":
-                        # CRITICAL: Wrap in try/except to ignore empty keep-alive lines that crash standard parsers
-                        try:
-                            # Replicate sometimes sends raw strings, sometimes JSON.
-                            # For chat models, it's usually a raw string token.
-                            # We try to load as JSON first, if it fails, use raw data.
-                            try:
-                                content = json.loads(data)
-                            except json.JSONDecodeError:
-                                content = data
-
-                            if content:  # Ensure we don't send empty chunks
+        # 2. Connect to the provided Stream URL and process Server-Sent Events (SSE)
+        try:
+            async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
+                current_event = None
+                async for line in sse.aiter_lines():
+                    if line.startswith("event:"):
+                        current_event = line[len("event:"):].strip()
+                    elif line.startswith("data:"):
+                        data = line[len("data:"):].strip()
+
+                        if current_event == "output":
+                            # The 'output' event for chat models sends one token at a time as a plain string.
+                            # We don't need to parse it as JSON.
+                            if data:  # Ensure we don't send empty chunks
                                chunk = {
                                    "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
-                                    "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}]
+                                    "choices": [{"index": 0, "delta": {"content": data}, "finish_reason": None}]
                                }
                                yield json.dumps(chunk)
-                        except Exception:
-                            pass  # Safely ignore malformed lines
-
-                    elif current_event == "done":
-                        break
-
-    # 3. Send final [DONE] event
-    yield json.dumps({"id": prediction_id, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]})
+
+                        elif current_event == "done":
+                            # The 'done' event signals the end of the stream.
+                            break
+        except httpx.ReadTimeout:
+            # Handle cases where the stream times out
+            yield json.dumps({"error": {"message": "Stream timed out.", "type": "timeout_error"}})
+            return
+
+    # 3. Send the final termination chunk in OpenAI format
+    final_chunk = {
+        "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
+        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
+    }
+    yield json.dumps(final_chunk)
+    # Some clients (like curl) expect a final "[DONE]" message to close the connection.
    yield "[DONE]"
 
 # --- Endpoints ---
 @app.get("/v1/models")
 async def list_models():
+    """Lists the currently supported models."""
    return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: OpenAIChatCompletionRequest):
+    """Handles chat completion requests, streaming or non-streaming."""
    if request.model not in SUPPORTED_MODELS:
-        raise HTTPException(404, f"Model not found. Available: {list(SUPPORTED_MODELS.keys())}")
+        raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
    replicate_id = SUPPORTED_MODELS[request.model]
-    replicate_input = prepare_replicate_input(request, replicate_id)
+    replicate_input = prepare_replicate_input(request)
 
    if request.stream:
-        return EventSourceResponse(stream_replicate_sse(replicate_id, replicate_input))
+        # Return a streaming response
+        return EventSourceResponse(stream_replicate_sse(replicate_id, replicate_input), media_type="text/event-stream")
 
    # Non-streaming fallback
    url = f"https://api.replicate.com/v1/models/{replicate_id}/predictions"
-    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=60"}
+    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}  # Increased wait time
    async with httpx.AsyncClient() as client:
-        resp = await client.post(url, headers=headers, json={"input": replicate_input})
-        if resp.is_error: raise HTTPException(resp.status_code, resp.text)
-        pred = resp.json()
-        output = "".join(pred.get("output", []))
-        return {"id": pred["id"], "choices": [{"message": {"role": "assistant", "content": output}, "finish_reason": "stop"}]}
+        try:
+            resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=130.0)
+            resp.raise_for_status()
+            pred = resp.json()
+            # The output of chat models is typically a list of strings (tokens)
+            output = "".join(pred.get("output", []))
+            return {
+                "id": pred.get("id"),
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": request.model,
+                "choices": [{
+                    "index": 0,
+                    "message": {"role": "assistant", "content": output},
+                    "finish_reason": "stop"
+                }],
+                "usage": {  # Placeholder usage object
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                    "total_tokens": 0
+                }
+            }
+        except httpx.HTTPStatusError as e:
+            raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
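
For reference, a minimal client sketch against the endpoint this commit changes. It is an illustration only, not part of the commit: the base URL "http://localhost:8000" and the assumption that the app is served with something like "uvicorn main:app" are mine. The point of the context fix above is that the full message history is now forwarded unchanged, so multi-turn context survives; setting "stream" to True would exercise the SSE path instead.

# client_sketch.py -- illustration only; assumes the app is running locally, e.g. via "uvicorn main:app --port 8000"
import httpx

BASE_URL = "http://localhost:8000"  # assumed host/port, not defined by this commit

payload = {
    "model": "claude-4.5-haiku",  # any key of SUPPORTED_MODELS
    "messages": [
        # After this commit the whole history is forwarded, so earlier turns are kept.
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Name one planet."},
        {"role": "assistant", "content": "Mars."},
        {"role": "user", "content": "Name another one."},
    ],
    "max_tokens": 64,
    "stream": False,  # True would use the SSE streaming path instead
}

resp = httpx.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=150.0)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])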