rkihacker committed
Commit bff4d10 · verified · 1 Parent(s): e58046c

Update main.py

Files changed (1)
  1. main.py +26 -44
main.py CHANGED
@@ -23,7 +23,7 @@ POLLING_INTERVAL_SECONDS = 1 # How often to poll for updates
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="1.1.1 (SyntaxError Fixed)",
+    version="1.2.0 (Streaming Fixed)",
 )
 
 # --- Pydantic Models for OpenAI Compatibility ---
@@ -71,12 +71,9 @@ SUPPORTED_MODELS = {
 # --- Helper Functions ---
 
 def format_tools_for_prompt(tools: List[Tool]) -> str:
-    """Converts OpenAI tools to a string for the system prompt."""
     if not tools:
         return ""
-
     prompt = "You have access to the following tools. To use a tool, respond with a JSON object in the following format:\n"
-    # *** THIS IS THE CORRECTED LINE ***
     prompt += '{"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}\n\n'
     prompt += "Available tools:\n"
     for tool in tools:
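The prompt above instructs the model to emit tool calls as a single JSON object, the same convention `create_chat_completion` later detects with `json.loads` (see the last two hunks). A quick round-trip sketch; the `get_weather` reply is a hypothetical model output, not something from this repo:

```python
import json

# Hypothetical model reply that follows the instructed format.
model_reply = '{"type": "tool_call", "name": "get_weather", "arguments": {"city": "Paris"}}'

try:
    tool_call_data = json.loads(model_reply)
    is_tool_call = tool_call_data.get("type") == "tool_call"
except (json.JSONDecodeError, TypeError):
    # Plain prose replies are not valid JSON and land here.
    is_tool_call = False

print(is_tool_call)  # True
```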
@@ -84,7 +81,6 @@ def format_tools_for_prompt(tools: List[Tool]) -> str:
     return prompt
 
 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
-    """Prepares the input payload for the Replicate API."""
     input_data = {}
     prompt_parts = []
     system_prompt = ""
@@ -127,13 +123,14 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
 
 async def stream_replicate_with_polling(model_id: str, payload: dict):
     """
-    Creates a prediction and then polls the 'get' URL to stream back results.
+    Creates a prediction and polls the 'get' URL to stream back results.
+    Yields raw JSON strings for EventSourceResponse to handle.
     """
     url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
     async with httpx.AsyncClient(timeout=300) as client:
-        # 1. Start the prediction
+        prediction = None
         try:
             response = await client.post(url, headers=headers, json={"input": payload})
             response.raise_for_status()
@@ -142,13 +139,14 @@ async def stream_replicate_with_polling(model_id: str, payload: dict):
 
             if not get_url:
                 error_detail = prediction.get("detail", "Failed to start prediction.")
-                yield f"data: {json.dumps({'error': error_detail})}\n\n"
+                error_chunk = {"error": {"message": error_detail, "type": "api_error", "code": 500}}
+                yield json.dumps(error_chunk)
                 return
         except httpx.HTTPStatusError as e:
-            yield f"data: {json.dumps({'error': str(e.response.text)})}\n\n"
+            error_chunk = {"error": {"message": e.response.text, "type": "api_error", "code": e.response.status_code}}
+            yield json.dumps(error_chunk)
             return
 
-        # 2. Poll the prediction 'get' URL for updates
         previous_output = ""
         status = ""
         while status not in ["succeeded", "failed", "canceled"]:
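The new error chunks (and every chunk below) are yielded as bare JSON strings because, per the updated docstring, the SSE framing is delegated to EventSourceResponse. A minimal, self-contained sketch of that pairing, assuming sse-starlette; the demo app, route, and fake generator are illustrative, not part of this commit:

```python
import asyncio
import json

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

demo = FastAPI()  # illustrative app, separate from main.py

async def fake_stream():
    # Stands in for stream_replicate_with_polling(): bare JSON strings, then "[DONE]".
    for i in range(3):
        yield json.dumps({"choices": [{"index": 0, "delta": {"content": f"tok{i}"}, "finish_reason": None}]})
        await asyncio.sleep(0.1)
    yield "[DONE]"

@demo.post("/demo/stream")
async def demo_stream():
    # EventSourceResponse writes each yielded string as "data: <string>\n\n",
    # which is why the generator must not add the "data: " prefix itself.
    return EventSourceResponse(fake_stream())
```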
@@ -161,53 +159,44 @@ async def stream_replicate_with_polling(model_id: str, payload: dict):
 
                 if status == "failed":
                     error_detail = prediction_update.get("error", "Prediction failed.")
-                    yield f"data: {json.dumps({'error': error_detail})}\n\n"
+                    chunk = {"choices": [{"delta": {"content": f"\n\n[ERROR: {error_detail}]"}, "finish_reason": "error"}]}
+                    yield json.dumps(chunk)
                     break
 
                 if "output" in prediction_update and prediction_update["output"] is not None:
                     current_output = "".join(prediction_update["output"])
-                    new_chunk = current_output[len(previous_output):]
+                    new_chunk_text = current_output[len(previous_output):]
 
-                    if new_chunk:
+                    if new_chunk_text:
                         chunk = {
-                            "id": prediction["id"],
-                            "object": "chat.completion.chunk",
-                            "created": int(time.time()),
-                            "model": model_id,
-                            "choices": [{"index": 0, "delta": {"content": new_chunk}, "finish_reason": None}]
+                            "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
+                            "choices": [{"index": 0, "delta": {"content": new_chunk_text}, "finish_reason": None}]
                         }
-                        yield f"data: {json.dumps(chunk)}\n\n"
+                        yield json.dumps(chunk)  # *** FIX: Yield raw JSON string
                         previous_output = current_output
 
-            except httpx.HTTPStatusError as e:
-                print(f"Warning: Polling failed with status {e.response.status_code}, retrying...")
             except Exception as e:
-                yield f"data: {json.dumps({'error': f'Polling error: {str(e)}'})}\n\n"
+                error_chunk = {"error": {"message": f"Polling error: {str(e)}", "type": "internal_error", "code": 500}}
+                yield json.dumps(error_chunk)
                 break
 
         # Send the final done signal
         done_chunk = {
-            "id": prediction["id"],
-            "object": "chat.completion.chunk",
-            "created": int(time.time()),
-            "model": model_id,
+            "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
             "choices": [{"index": 0, "delta": {}, "finish_reason": "stop" if status == "succeeded" else "error"}]
         }
-        yield f"data: {json.dumps(done_chunk)}\n\n"
-        yield "data: [DONE]\n\n"
+        yield json.dumps(done_chunk)  # *** FIX: Yield raw JSON string
+        yield "[DONE]"  # *** FIX: Yield the special [DONE] marker
 
 
 # --- API Endpoints ---
 
 @app.get("/v1/models", response_model=ModelList)
 async def list_models():
-    """Lists the available models."""
-    model_cards = [ModelCard(id=model_name) for model_name in SUPPORTED_MODELS.keys()]
-    return ModelList(data=model_cards)
+    return ModelList(data=[ModelCard(id=model_name) for model_name in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: OpenAIChatCompletionRequest):
-    """Creates a chat completion."""
     model_key = request.model
     if model_key not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")
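Because Replicate's polling endpoint returns the cumulative output, the loop above streams only the unsent suffix (`current_output[len(previous_output):]`) and finishes with the `[DONE]` sentinel. A hypothetical client that consumes the resulting stream with httpx; the URL, port, and model name are placeholders:

```python
import json
import httpx

def read_stream(base_url: str = "http://localhost:8000") -> None:
    payload = {"model": "some-model", "messages": [{"role": "user", "content": "Hi"}], "stream": True}
    with httpx.stream("POST", f"{base_url}/v1/chat/completions", json=payload, timeout=None) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separators and SSE comments
            data = line[len("data: "):]
            if data == "[DONE]":
                break  # sentinel yielded after the final chunk
            chunk = json.loads(data)
            if "error" in chunk:  # shape of the start-up error chunks above
                raise RuntimeError(chunk["error"]["message"])
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)
```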
@@ -228,11 +217,8 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
             response.raise_for_status()
             prediction = response.json()
 
-            output = prediction.get("output", "")
-            if isinstance(output, list):
-                output = "".join(output)
-
-            # Basic tool call detection
+            output = "".join(prediction.get("output", []))
+
             try:
                 tool_call_data = json.loads(output)
                 if tool_call_data.get("type") == "tool_call":
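One behavioral note on the simplified join: the old branch also accepted a plain string `output`, while the new one-liner assumes Replicate returns the text as a list of fragments (and would raise a TypeError if `output` were present but null). A sketch with an illustrative prediction payload:

```python
# Illustrative payload; Replicate language models return "output" as a list of fragments.
prediction = {"id": "p1", "status": "succeeded", "output": ["Hel", "lo", "!"]}

output = "".join(prediction.get("output", []))
print(output)  # Hello!
```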
@@ -242,15 +228,11 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
             except (json.JSONDecodeError, TypeError):
                 message_content, tool_calls = output, None
 
-            completion_response = {
-                "id": prediction["id"],
-                "object": "chat.completion",
-                "created": int(time.time()),
-                "model": model_key,
+            return JSONResponse(content={
+                "id": prediction["id"], "object": "chat.completion", "created": int(time.time()), "model": model_key,
                 "choices": [{"index": 0, "message": {"role": "assistant", "content": message_content, "tool_calls": tool_calls}, "finish_reason": "stop"}],
                 "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-            }
-            return JSONResponse(content=completion_response)
+            })
 
     except httpx.HTTPStatusError as e:
         raise HTTPException(status_code=e.response.status_code, detail=e.response.text)
 