rkihacker committed
Commit 1120bba · verified · 1 Parent(s): 2b507a6

Update main.py

Files changed (1): main.py +79 -99
main.py CHANGED
@@ -2,7 +2,8 @@ import os
 import httpx
 import json
 import time
-from fastapi import FastAPI, Request, HTTPException, Header
+import asyncio
+from fastapi import FastAPI, Request, HTTPException
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
 from typing import List, Dict, Any, Optional, Union, Literal
@@ -17,13 +18,15 @@ REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
 if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
+POLLING_INTERVAL_SECONDS = 1  # How often to poll for updates
+
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="1.0.0",
+    version="1.1.0 (Polling Strategy)",
 )
 
-# --- Pydantic Models for OpenAI Compatibility ---
+# --- Pydantic Models for OpenAI Compatibility (No Changes) ---
 
 # /v1/models endpoint
 class ModelCard(BaseModel):
@@ -61,7 +64,6 @@ class OpenAIChatCompletionRequest(BaseModel):
     tool_choice: Optional[Union[str, Dict]] = None
 
 # --- Replicate Model Mapping ---
-# We hardcode the models we want to expose.
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
     "claude-4.5-haiku": "anthropic/claude-4.5-haiku"
@@ -74,9 +76,8 @@ def format_tools_for_prompt(tools: List[Tool]) -> str:
     """Converts OpenAI tools to a string for the system prompt."""
     if not tools:
         return ""
-
     prompt = "You have access to the following tools. To use a tool, respond with a JSON object in the following format:\n"
-    prompt += '{"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}\n\n'
+    prompt += '{"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}\n\n'
     prompt += "Available tools:\n"
     for tool in tools:
         prompt += json.dumps(tool.function.dict(), indent=2) + "\n"
@@ -87,25 +88,24 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     input_data = {}
     prompt_parts = []
     system_prompt = ""
-
-    # Handle messages, separating system, user, assistant and vision content
     image_url = None
+
     for message in request.messages:
         if message.role == "system":
-            system_prompt += message.content + "\n"
+            system_prompt += str(message.content) + "\n"
         elif message.role == "user":
-            if isinstance(message.content, list): # Vision support
-                for item in message.content:
+            content = message.content
+            if isinstance(content, list):
+                for item in content:
                     if item.get("type") == "text":
                         prompt_parts.append(f"User: {item.get('text', '')}")
                     elif item.get("type") == "image_url":
                         image_url = item.get("image_url", {}).get("url")
             else:
-                prompt_parts.append(f"User: {message.content}")
+                prompt_parts.append(f"User: {str(content)}")
         elif message.role == "assistant":
-            prompt_parts.append(f"Assistant: {message.content}")
+            prompt_parts.append(f"Assistant: {str(message.content)}")
 
-    # Add tool instructions to system prompt
     if request.tools:
         tool_prompt = format_tools_for_prompt(request.tools)
         system_prompt += "\n" + tool_prompt
@@ -116,75 +116,84 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     if image_url:
         input_data["image"] = image_url
 
-    # Map other parameters
     if request.temperature is not None:
         input_data["temperature"] = request.temperature
     if request.top_p is not None:
         input_data["top_p"] = request.top_p
     if request.max_tokens is not None:
-        # Replicate uses `max_new_tokens` or `max_tokens` depending on model
         input_data["max_new_tokens"] = request.max_tokens
 
     return input_data
 
-
-async def stream_replicate_response(model_id: str, payload: dict):
-    """Generator for streaming Replicate responses."""
+async def stream_replicate_with_polling(model_id: str, payload: dict):
+    """
+    Creates a prediction and then polls the 'get' URL to stream back results.
+    This is a reliable alternative to Replicate's native SSE stream.
+    """
     url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
-    headers = {
-        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
-        "Content-Type": "application/json",
-    }
+    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
     async with httpx.AsyncClient(timeout=300) as client:
-        # 1. Create the prediction and get the stream URL
-        payload["stream"] = True
+        # 1. Start the prediction
        try:
            response = await client.post(url, headers=headers, json={"input": payload})
            response.raise_for_status()
            prediction = response.json()
-            stream_url = prediction.get("urls", {}).get("stream")
+            get_url = prediction.get("urls", {}).get("get")
 
-            if not stream_url:
-                yield f"data: {json.dumps({'error': 'Failed to get stream URL'})}\n\n"
+            if not get_url:
+                error_detail = prediction.get("detail", "Failed to start prediction.")
+                yield f"data: {json.dumps({'error': error_detail})}\n\n"
                 return
         except httpx.HTTPStatusError as e:
             yield f"data: {json.dumps({'error': str(e.response.text)})}\n\n"
             return
-
-        # 2. Connect to the SSE stream
-        try:
-            async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}) as sse:
-                async for line in sse.aiter_lines():
-                    if line.startswith("data:"):
-                        event_data = line[len("data:"):].strip()
-                        try:
-                            data = json.loads(event_data)
-                            # Format as OpenAI chunk
-                            chunk = {
-                                "id": prediction["id"],
-                                "object": "chat.completion.chunk",
-                                "created": int(time.time()),
-                                "model": model_id,
-                                "choices": [{
-                                    "index": 0,
-                                    "delta": {"content": data},
-                                    "finish_reason": None
-                                }]
-                            }
-                            yield f"data: {json.dumps(chunk)}\n\n"
-                        except json.JSONDecodeError:
-                            continue # Skip non-json lines
-        except Exception as e:
-            yield f"data: {json.dumps({'error': f'Streaming error: {str(e)}'})}\n\n"
-
-        # Send the done signal
+
+        # 2. Poll the prediction 'get' URL for updates
+        previous_output = ""
+        status = ""
+        while status not in ["succeeded", "failed", "canceled"]:
+            await asyncio.sleep(POLLING_INTERVAL_SECONDS)
+            try:
+                poll_response = await client.get(get_url, headers=headers)
+                poll_response.raise_for_status()
+                prediction_update = poll_response.json()
+                status = prediction_update["status"]
+
+                if status == "failed":
+                    error_detail = prediction_update.get("error", "Prediction failed.")
+                    yield f"data: {json.dumps({'error': error_detail})}\n\n"
+                    break
+
+                if "output" in prediction_update and prediction_update["output"] is not None:
+                    current_output = "".join(prediction_update["output"])
+                    new_chunk = current_output[len(previous_output):]
+
+                    if new_chunk:
+                        chunk = {
+                            "id": prediction["id"],
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": model_id,
+                            "choices": [{"index": 0, "delta": {"content": new_chunk}, "finish_reason": None}]
+                        }
+                        yield f"data: {json.dumps(chunk)}\n\n"
+                        previous_output = current_output
+
+            except httpx.HTTPStatusError as e:
+                # Don't stop polling on temporary network errors
+                print(f"Warning: Polling failed with status {e.response.status_code}, retrying...")
+            except Exception as e:
+                yield f"data: {json.dumps({'error': f'Polling error: {str(e)}'})}\n\n"
+                break
+
+        # Send the final done signal
         done_chunk = {
             "id": prediction["id"],
             "object": "chat.completion.chunk",
             "created": int(time.time()),
             "model": model_id,
-            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
+            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop" if status == "succeeded" else "error"}]
         }
         yield f"data: {json.dumps(done_chunk)}\n\n"
         yield "data: [DONE]\n\n"
@@ -194,15 +203,13 @@ async def stream_replicate_response(model_id: str, payload: dict):
 
 @app.get("/v1/models", response_model=ModelList)
 async def list_models():
-    """Lists the available models that this compatibility layer supports."""
-    model_cards = [
-        ModelCard(id=model_name) for model_name in SUPPORTED_MODELS.keys()
-    ]
+    """Lists the available models."""
+    model_cards = [ModelCard(id=model_name) for model_name in SUPPORTED_MODELS.keys()]
     return ModelList(data=model_cards)
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: OpenAIChatCompletionRequest):
-    """Creates a chat completion, either streaming or synchronous."""
+    """Creates a chat completion."""
     model_key = request.model
     if model_key not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")
@@ -211,15 +218,12 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     replicate_input = prepare_replicate_input(request)
 
     if request.stream:
-        return EventSourceResponse(stream_replicate_response(replicate_model_id, replicate_input))
+        # Use the new reliable polling-based streamer
+        return EventSourceResponse(stream_replicate_with_polling(replicate_model_id, replicate_input))
 
-    # Synchronous request
+    # Synchronous request (no changes needed here)
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
-    headers = {
-        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
-        "Content-Type": "application/json",
-        "Prefer": "wait=120" # Wait up to 120 seconds for a response
-    }
+    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
 
     async with httpx.AsyncClient(timeout=150) as client:
         try:
@@ -231,47 +235,23 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
         if isinstance(output, list):
             output = "".join(output)
 
-        # Check for tool call
+        # Basic tool call detection
         try:
-            # A simple check if the output is a JSON for a tool call
             tool_call_data = json.loads(output)
             if tool_call_data.get("type") == "tool_call":
-                message_content = None
-                tool_calls = [{
-                    "id": f"call_{int(time.time())}",
-                    "type": "function",
-                    "function": {
-                        "name": tool_call_data["name"],
-                        "arguments": json.dumps(tool_call_data["arguments"])
-                    }
-                }]
+                message_content, tool_calls = None, [{"id": f"call_{int(time.time())}", "type": "function", "function": {"name": tool_call_data["name"], "arguments": json.dumps(tool_call_data["arguments"])}}]
             else:
-                message_content = output
-                tool_calls = None
+                message_content, tool_calls = output, None
         except (json.JSONDecodeError, TypeError):
-            message_content = output
-            tool_calls = None
+            message_content, tool_calls = output, None
 
-        # Format response in OpenAI format
         completion_response = {
             "id": prediction["id"],
             "object": "chat.completion",
             "created": int(time.time()),
             "model": model_key,
-            "choices": [{
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": message_content,
-                    "tool_calls": tool_calls,
-                },
-                "finish_reason": "stop" # Or map from Replicate if available
-            }],
-            "usage": { # Note: Replicate doesn't provide token usage in the same way
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "total_tokens": 0
-            }
+            "choices": [{"index": 0, "message": {"role": "assistant", "content": message_content, "tool_calls": tool_calls}, "finish_reason": "stop"}],
+            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
         }
         return JSONResponse(content=completion_response)
 
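For reference, a minimal smoke test for the new polling-based streamer. This is a sketch, not part of the commit: it assumes the app is running locally (e.g. uvicorn main:app --port 8000), and the model name and prompt are placeholders. It uses httpx, which main.py already depends on, and prints the raw SSE lines as they arrive:

import asyncio
import httpx

async def main():
    # OpenAI-style request body; stream=True routes to stream_replicate_with_polling.
    payload = {
        "model": "llama3-8b-instruct",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/v1/chat/completions", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if line:
                    # Each event carries an OpenAI-style chat.completion.chunk,
                    # followed by the final [DONE] sentinel.
                    print(line)

asyncio.run(main())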
 
257