rkihacker committed
Commit ea53c08 · verified · 1 Parent(s): 97aa2c2

Update main.py

Files changed (1)
  1. main.py +69 -113
main.py CHANGED
@@ -3,7 +3,7 @@ import httpx
 import json
 import time
 import asyncio
-from fastapi import FastAPI, Request, HTTPException
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
 from typing import List, Dict, Any, Optional, Union, Literal
@@ -18,15 +18,10 @@ REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
 if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")

-# *** THE FIX IS HERE ***
-# Reduced from 1.0 to 0.05 for smoother, more frequent streaming updates.
-# This makes the polling fast enough to appear like real-time token streaming.
-POLLING_INTERVAL_SECONDS = 0.05
-
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="1.3.0 (Smooth Streaming)",
+    version="2.0.0 (Native Streaming & Context Fixed)",
 )

 # --- Pydantic Models for OpenAI Compatibility ---
@@ -73,122 +68,92 @@ SUPPORTED_MODELS = {

 # --- Helper Functions ---

-def format_tools_for_prompt(tools: List[Tool]) -> str:
-    if not tools:
-        return ""
-    prompt = "You have access to the following tools. To use a tool, respond with a JSON object in the following format:\n"
-    prompt += '{"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}\n\n'
-    prompt += "Available tools:\n"
-    for tool in tools:
-        prompt += json.dumps(tool.function.dict(), indent=2) + "\n"
-    return prompt
-
 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
-    input_data = {}
-    prompt_parts = []
-    system_prompt = ""
-    image_url = None
-
-    for message in request.messages:
-        if message.role == "system":
-            system_prompt += str(message.content) + "\n"
-        elif message.role == "user":
-            content = message.content
-            if isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "text":
-                        prompt_parts.append(f"User: {item.get('text', '')}")
-                    elif item.get("type") == "image_url":
-                        image_url = item.get("image_url", {}).get("url")
-            else:
-                prompt_parts.append(f"User: {str(content)}")
-        elif message.role == "assistant":
-            prompt_parts.append(f"Assistant: {str(message.content)}")
-
-    if request.tools:
-        tool_prompt = format_tools_for_prompt(request.tools)
-        system_prompt += "\n" + tool_prompt
-
-    # Add final turn for the assistant to respond
-    prompt_parts.append("Assistant:")
-
-    input_data["prompt"] = "\n".join(prompt_parts)
-    if system_prompt:
-        input_data["system_prompt"] = system_prompt
-    if image_url:
-        input_data["image"] = image_url
+    """
+    Prepares the input payload for Replicate's chat models.
+    This now correctly passes the messages array for context.
+    """
+    # Convert Pydantic message objects to a list of dictionaries
+    messages_for_replicate = [msg.dict() for msg in request.messages]

+    payload = {
+        "messages": messages_for_replicate
+    }
+
+    # Add other compatible parameters
+    if request.max_tokens is not None:
+        payload["max_new_tokens"] = request.max_tokens
     if request.temperature is not None:
-        input_data["temperature"] = request.temperature
+        payload["temperature"] = request.temperature
     if request.top_p is not None:
-        input_data["top_p"] = request.top_p
-    if request.max_tokens is not None:
-        input_data["max_new_tokens"] = request.max_tokens
-
-    return input_data
+        payload["top_p"] = request.top_p
+
+    # Vision support: Find image URL in the last user message if present
+    last_user_message = next((m for m in reversed(request.messages) if m.role == 'user'), None)
+    if last_user_message and isinstance(last_user_message.content, list):
+        for item in last_user_message.content:
+            if item.get("type") == "image_url":
+                payload["image"] = item.get("image_url", {}).get("url")
+                # Reformat messages to be a simple prompt string for vision models if needed,
+                # as some might not support the `messages` format with images.
+                # For Claude Haiku, a prompt string is more reliable with images.
+                if "claude" in request.model:
+                    text_prompts = [item.get('text', '') for item in last_user_message.content if item.get('type') == 'text']
+                    payload["prompt"] = " ".join(text_prompts)
+                    del payload["messages"]
+                break
+
+    return payload

-async def stream_replicate_with_polling(model_id: str, payload: dict):
+async def stream_replicate_native_sse(model_id: str, payload: dict):
     """
-    Creates a prediction and polls the 'get' URL to stream back results.
-    Yields raw JSON strings for EventSourceResponse to handle.
+    Connects to Replicate's native SSE stream for true token-by-token streaming.
     """
     url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}

     async with httpx.AsyncClient(timeout=300) as client:
-        prediction = None
+        # 1. Create the prediction to get the stream URL
         try:
-            response = await client.post(url, headers=headers, json={"input": payload})
+            # Add stream=True to the outer payload for Replicate
+            response = await client.post(url, headers=headers, json={"input": payload, "stream": True})
             response.raise_for_status()
             prediction = response.json()
-            get_url = prediction.get("urls", {}).get("get")
+            stream_url = prediction.get("urls", {}).get("stream")

-            if not get_url:
-                error_detail = prediction.get("detail", "Failed to start prediction.")
-                error_chunk = {"error": {"message": error_detail, "type": "api_error", "code": 500}}
-                yield json.dumps(error_chunk)
+            if not stream_url:
+                error_detail = prediction.get("detail", "Failed to get stream URL.")
+                yield json.dumps({"error": {"message": error_detail}})
                 return
         except httpx.HTTPStatusError as e:
-            error_chunk = {"error": {"message": e.response.text, "type": "api_error", "code": e.response.status_code}}
-            yield json.dumps(error_chunk)
+            yield json.dumps({"error": {"message": e.response.text}})
             return

-        previous_output = ""
-        status = ""
-        while status not in ["succeeded", "failed", "canceled"]:
-            await asyncio.sleep(POLLING_INTERVAL_SECONDS)
-            try:
-                poll_response = await client.get(get_url, headers=headers)
-                poll_response.raise_for_status()
-                prediction_update = poll_response.json()
-                status = prediction_update["status"]
-
-                if status == "failed":
-                    error_detail = prediction_update.get("error", "Prediction failed.")
-                    chunk = {"choices": [{"delta": {"content": f"\n\n[ERROR: {error_detail}]"}, "finish_reason": "error"}]}
-                    yield json.dumps(chunk)
-                    break
-
-                if "output" in prediction_update and prediction_update["output"] is not None:
-                    current_output = "".join(prediction_update["output"])
-                    new_chunk_text = current_output[len(previous_output):]
-
-                    if new_chunk_text:
-                        chunk = {
-                            "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
-                            "choices": [{"index": 0, "delta": {"content": new_chunk_text}, "finish_reason": None}]
-                        }
-                        yield json.dumps(chunk)
-                        previous_output = current_output
-
-            except Exception as e:
-                error_chunk = {"error": {"message": f"Polling error: {str(e)}", "type": "internal_error", "code": 500}}
-                yield json.dumps(error_chunk)
-                break
-
+        # 2. Connect to the SSE stream and yield OpenAI-compatible chunks
+        try:
+            async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}) as sse:
+                sse.raise_for_status()
+                current_event = ""
+                async for line in sse.aiter_lines():
+                    if line.startswith("event:"):
+                        current_event = line[len("event:"):].strip()
+                    elif line.startswith("data:"):
+                        data = line[len("data:"):].strip()
+                        if current_event == "output":
+                            chunk = {
+                                "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
+                                "choices": [{"index": 0, "delta": {"content": json.loads(data)}, "finish_reason": None}]
+                            }
+                            yield json.dumps(chunk)
+                        elif current_event == "done":
+                            break  # Exit loop when done event is received
+        except Exception as e:
+            yield json.dumps({"error": {"message": f"Streaming error: {str(e)}"}})
+
+        # 3. Send the final DONE chunk
         done_chunk = {
             "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
-            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop" if status == "succeeded" else "error"}]
+            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
         }
         yield json.dumps(done_chunk)
         yield "[DONE]"
@@ -210,7 +175,7 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     replicate_input = prepare_replicate_input(request)

     if request.stream:
-        return EventSourceResponse(stream_replicate_with_polling(replicate_model_id, replicate_input))
+        return EventSourceResponse(stream_replicate_native_sse(replicate_model_id, replicate_input))

     # Synchronous request
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
@@ -224,18 +189,9 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):

     output = "".join(prediction.get("output", []))

-    try:
-        tool_call_data = json.loads(output)
-        if tool_call_data.get("type") == "tool_call":
-            message_content, tool_calls = None, [{"id": f"call_{int(time.time())}", "type": "function", "function": {"name": tool_call_data["name"], "arguments": json.dumps(tool_call_data["arguments"])}}]
-        else:
-            message_content, tool_calls = output, None
-    except (json.JSONDecodeError, TypeError):
-        message_content, tool_calls = output, None
-
     return JSONResponse(content={
         "id": prediction["id"], "object": "chat.completion", "created": int(time.time()), "model": model_key,
-        "choices": [{"index": 0, "message": {"role": "assistant", "content": message_content, "tool_calls": tool_calls}, "finish_reason": "stop"}],
+        "choices": [{"index": 0, "message": {"role": "assistant", "content": output}, "finish_reason": "stop"}],
         "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
     })
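
For a concrete picture of what the new prepare_replicate_input emits, the sketch below traces one request through the mapping. The model key and message text are invented for illustration; only what the diff shows is assumed: the full messages array is forwarded as-is (the context fix) and OpenAI's max_tokens is renamed to Replicate's max_new_tokens.

# Illustrative only -- not part of main.py. The model key and message
# contents are placeholders; the field mapping follows the diff above.
openai_request = {
    "model": "claude-3-haiku",  # hypothetical key in SUPPORTED_MODELS
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Why is the sky blue?"},
    ],
    "max_tokens": 256,
    "temperature": 0.7,
}

# What prepare_replicate_input would build from that request: the whole
# messages array passes through unchanged, and max_tokens becomes
# Replicate's max_new_tokens.
replicate_input = {
    "messages": openai_request["messages"],
    "max_new_tokens": 256,
    "temperature": 0.7,
}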
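
The streaming handler assumes Replicate's stream URL speaks standard server-sent events, with each output token arriving as a JSON-encoded string on a data: line (hence the json.loads(data) in the diff). A minimal, self-contained sketch of that parsing loop; the sample events below are fabricated for the example, not a captured Replicate response:

import json

sample_sse = (
    'event: output\n'
    'data: "Hello"\n'
    '\n'
    'event: output\n'
    'data: " world"\n'
    '\n'
    'event: done\n'
    'data: {}\n'
)

current_event = ""
for line in sample_sse.splitlines():
    if line.startswith("event:"):
        current_event = line[len("event:"):].strip()
    elif line.startswith("data:"):
        data = line[len("data:"):].strip()
        if current_event == "output":
            # main.py json.loads() each output payload before wrapping it
            # in an OpenAI-style chat.completion.chunk delta.
            print(json.loads(data), end="")
        elif current_event == "done":
            break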
 
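
Finally, a client-side sketch for exercising the streaming path end to end. The base URL, the /v1/chat/completions route (its decorator sits outside this diff), and the model key are all assumptions, not taken from the commit:

import httpx
import json

with httpx.stream(
    "POST",
    "http://localhost:8000/v1/chat/completions",  # assumed host and route
    json={
        "model": "claude-3-haiku",  # placeholder; must be a key in SUPPORTED_MODELS
        "messages": [{"role": "user", "content": "Say hi"}],
        "stream": True,
    },
    timeout=300,
) as response:
    for line in response.iter_lines():
        # EventSourceResponse frames each yielded JSON string as a "data:" line.
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk.get("choices", [{}])[0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)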