rkihacker committed on
Commit 0f99721 · verified · 1 Parent(s): 580cccc

Update main.py

Files changed (1)
  1. main.py +43 -109
main.py CHANGED
@@ -1,4 +1,3 @@
-
 import os
 import httpx
 import json
@@ -17,7 +16,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.1 (Spacing Fixed)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.2 (Spacing Fixed)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -25,7 +24,7 @@ class ModelCard(BaseModel):
 class ModelList(BaseModel):
     object: str = "list"; data: List[ModelCard] = []
 class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]; name: Optional[str] = None
+    role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]; name: Optional[str] = None; tool_calls: Optional[List[Any]] = None
 class FunctionDefinition(BaseModel):
     name: str; description: Optional[str] = None; parameters: Optional[Dict[str, Any]] = None
 class ToolDefinition(BaseModel):
@@ -57,33 +56,30 @@ class ChatCompletionChunk(BaseModel):
 # --- Supported Models ---
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
-    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
+    "claude-3-haiku-20240307": "anthropic/claude-3-haiku-20240307",  # Example of another common model
+    "claude-3-sonnet-20240229": "anthropic/claude-3-sonnet-20240229",
     "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
 }
 
 # --- Core Logic ---
 def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
-    """Convert OpenAI messages to Replicate-compatible format with function calling support."""
     prompt_parts = []
     system_prompt = None
     image_input = None
 
-    # Add functions to system prompt if provided
     if functions:
-        functions_text = "\n\nAvailable functions:\n"
+        functions_text = "You have access to the following tools. Use them if required to answer the user's question.\n\n"
         for func in functions:
-            functions_text += f"- {func.name}: {func.description or 'No description'}\n"
-            if func.parameters:
-                functions_text += f"  Parameters: {json.dumps(func.parameters)}\n"
+            functions_text += f"- Function: {func.name}\n"
+            if func.description: functions_text += f"  Description: {func.description}\n"
+            if func.parameters: functions_text += f"  Parameters: {json.dumps(func.parameters)}\n"
        prompt_parts.append(functions_text)
 
     for msg in messages:
         if msg.role == "system":
             system_prompt = str(msg.content)
         elif msg.role == "assistant":
-            # Handle tool calls in assistant messages
-            if hasattr(msg, 'tool_calls') and msg.tool_calls:
+            if msg.tool_calls:
                 tool_calls_text = "\nTool calls:\n"
                 for tool_call in msg.tool_calls:
                     tool_calls_text += f"- {tool_call.function.name}({tool_call.function.arguments})\n"
@@ -91,7 +87,6 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
             else:
                 prompt_parts.append(f"Assistant: {msg.content}")
         elif msg.role == "tool":
-            # Handle tool responses
             prompt_parts.append(f"Tool Response: {msg.content}")
         elif msg.role == "user":
             user_text_content = ""
@@ -106,8 +101,7 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
                 user_text_content = str(msg.content)
             prompt_parts.append(f"User: {user_text_content}")
 
-    # Fix: Don't add trailing space, let model decide spacing
-    prompt_parts.append("Assistant:")
+    prompt_parts.append("Assistant:")  # Let the model generate the space after this
     return {
         "prompt": "\n\n".join(prompt_parts),
         "system_prompt": system_prompt,
@@ -115,11 +109,8 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
     }
 
 def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
-    """Parse function call from model response."""
     try:
-        # Look for JSON-like function call patterns
         if "function_call" in content or ("name" in content and "arguments" in content):
-            # Extract JSON part
             start = content.find("{")
             end = content.rfind("}") + 1
             if start != -1 and end > start:
@@ -132,7 +123,6 @@ def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
     return None
 
 async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str):
-    """Stream response with full OpenAI compatibility including tool calls."""
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
@@ -151,9 +141,7 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
             return
     except httpx.HTTPStatusError as e:
         error_details = e.response.text
-        try:
-            error_json = e.response.json()
-            error_details = error_json.get("detail", error_details)
+        try: error_details = e.response.json().get("detail", error_details)
         except json.JSONDecodeError: pass
         yield f"data: {json.dumps({'error': {'message': f'Upstream Error: {error_details}', 'type': 'replicate_error'}})}\n\n"
         return
@@ -172,80 +160,39 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
             elif line.startswith("data:") and current_event == "output":
                 raw_data = line[5:].strip()
                 if not raw_data: continue
-
-                content_token = ""
+
                 try:
                     content_token = json.loads(raw_data)
                 except (json.JSONDecodeError, TypeError):
                     content_token = raw_data
 
-                # Fix: Handle spacing properly - don't prepend space to first token
+                # ### MAJOR FIX HERE ###
+                # This logic robustly handles the leading space by only stripping
+                # the very first non-empty token of the entire stream.
                 if first_token:
                     content_token = content_token.lstrip()
-                    first_token = False
+                    # Only flip the flag if we've actually processed a token with content.
+                    if content_token:
+                        first_token = False
 
                 accumulated_content += content_token
                 completion_tokens += 1
 
-                # Check for function calls in accumulated content
                 function_call = parse_function_call(accumulated_content)
                 if function_call:
-                    # Send tool call chunk
-                    tool_call = ToolCall(
-                        id=f"call_{int(time.time())}",
-                        function=FunctionCall(
-                            name=function_call["name"],
-                            arguments=function_call["arguments"]
-                        )
-                    )
-                    chunk = ChatCompletionChunk(
-                        id=request_id,
-                        created=int(time.time()),
-                        model=replicate_model_id,
-                        choices=[ChoiceDelta(
-                            index=0,
-                            delta=DeltaMessage(tool_calls=[tool_call]),
-                            finish_reason=None
-                        )]
-                    )
+                    tool_call = ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))
+                    chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
                     yield f"data: {chunk.json()}\n\n"
                 else:
-                    # Send regular content chunk
-                    chunk = ChatCompletionChunk(
-                        id=request_id,
-                        created=int(time.time()),
-                        model=replicate_model_id,
-                        choices=[ChoiceDelta(
-                            index=0,
-                            delta=DeltaMessage(content=content_token),
-                            finish_reason=None
-                        )]
-                    )
-                    yield f"data: {chunk.json()}\n\n"
+                    # Only yield a chunk if there is content to send.
+                    if content_token:
+                        chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
+                        yield f"data: {chunk.json()}\n\n"
 
             elif current_event == "done":
-                # Send final usage chunk
                 end_time = time.time()
-                inference_time = end_time - start_time
-
-                usage = Usage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
-                    inference_time=round(inference_time, 3)
-                )
-
-                usage_chunk = ChatCompletionChunk(
-                    id=request_id,
-                    created=int(time.time()),
-                    model=replicate_model_id,
-                    choices=[ChoiceDelta(
-                        index=0,
-                        delta=DeltaMessage(),
-                        finish_reason="stop"
-                    )],
-                    usage=usage
-                )
+                usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, inference_time=round(end_time - start_time, 3))
+                usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(), finish_reason="stop")], usage=usage)
                 yield f"data: {usage_chunk.json()}\n\n"
                 break
 
@@ -265,7 +212,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
-    # Format messages for Replicate
+    replicate_model_id = SUPPORTED_MODELS[request.model]
     formatted = format_messages_for_replicate(request.messages, request.functions)
     replicate_input = {
         "prompt": formatted["prompt"],
@@ -273,21 +220,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
         "temperature": request.temperature or 0.7,
         "top_p": request.top_p or 1.0
     }
-    if formatted["system_prompt"]:
-        replicate_input["system_prompt"] = formatted["system_prompt"]
-    if formatted["image"]:
-        replicate_input["image"] = formatted["image"]
+    if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
+    if formatted["image"]: replicate_input["image"] = formatted["image"]
 
     request_id = f"chatcmpl-{int(time.time())}"
 
     if request.stream:
         return StreamingResponse(
-            stream_replicate_response(SUPPORTED_MODELS[request.model], replicate_input, request_id),
+            stream_replicate_response(replicate_model_id, replicate_input, request_id),
             media_type="text/event-stream"
        )
 
     # Non-streaming response
-    url = f"https://api.replicate.com/v1/models/{SUPPORTED_MODELS[request.model]}/predictions"
+    url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
     start_time = time.time()
 
@@ -298,27 +243,21 @@ async def create_chat_completion(request: ChatCompletionRequest):
         pred = resp.json()
         output = "".join(pred.get("output", []))
 
-        # Fix: Clean up leading/trailing whitespace
-        output = output.strip()
+        output = output.strip()  # Clean up any leading/trailing whitespace
 
-        # Calculate timing and tokens
         end_time = time.time()
-        inference_time = end_time - start_time
         prompt_tokens = len(replicate_input.get("prompt", "")) // 4
        completion_tokens = len(output) // 4
 
-        # Parse function call if present
         tool_calls = None
+        finish_reason = "stop"
+        message_content = output
+
        function_call = parse_function_call(output)
        if function_call:
-            tool_call = ToolCall(
-                id=f"call_{int(time.time())}",
-                function=FunctionCall(
-                    name=function_call["name"],
-                    arguments=function_call["arguments"]
-                )
-            )
-            tool_calls = [tool_call]
+            tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))]
+            finish_reason = "tool_calls"
+            message_content = None  # OpenAI standard: content is null when tool_calls are present
 
        return ChatCompletion(
            id=request_id,
@@ -326,18 +265,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
            model=request.model,
            choices=[Choice(
                index=0,
-                message=ChatMessage(
-                    role="assistant",
-                    content=output if not function_call else None,
-                    tool_calls=tool_calls
-                ),
-                finish_reason="tool_calls" if function_call else "stop"
+                message=ChatMessage(role="assistant", content=message_content, tool_calls=tool_calls),
+                finish_reason=finish_reason
            )],
            usage=Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
-                inference_time=round(inference_time, 3)
+                inference_time=round(end_time - start_time, 3)
            )
        )
    except httpx.HTTPStatusError as e:
@@ -347,14 +282,13 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
 @app.get("/")
 async def root():
-    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.1"}
+    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.2"}
 
-# Performance optimization middleware
 @app.middleware("http")
 async def add_performance_headers(request, call_next):
     start_time = time.time()
     response = await call_next(request)
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = str(round(process_time, 3))
-    response.headers["X-API-Version"] = "9.2.1"
+    response.headers["X-API-Version"] = "9.2.2"
     return response
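For reference, a minimal sketch of the kind of model output parse_function_call above is designed to catch. The sample reply string and the get_weather arguments are illustrative assumptions; actual model output varies and is not guaranteed by this diff:

    import json

    # Hypothetical model reply: it contains "name" and "arguments", which is
    # the trigger condition in parse_function_call. The parser slices from the
    # first "{" to the last "}" and feeds that span to json.loads.
    raw = 'Let me check that. {"name": "get_weather", "arguments": "{\\"city\\": \\"Paris\\"}"}'
    start, end = raw.find("{"), raw.rfind("}") + 1
    call = json.loads(raw[start:end])
    print(call["name"])       # get_weather
    print(call["arguments"])  # {"city": "Paris"}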
 
 
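To exercise the change end to end, a minimal streaming client sketch. It assumes the service runs locally on port 8000 and exposes an OpenAI-style /v1/chat/completions route; neither detail appears in this diff, so treat both as assumptions:

    import json
    import httpx

    payload = {
        "model": "llama3-8b-instruct",  # any key from SUPPORTED_MODELS
        "messages": [{"role": "user", "content": "Say hello."}],
        "stream": True,
    }

    # Consume the SSE stream chunk by chunk, printing content deltas.
    # The leading-space fix means the first delta arrives already lstripped.
    with httpx.stream("POST", "http://localhost:8000/v1/chat/completions",
                      json=payload, timeout=60) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data:"):
                continue
            chunk = json.loads(line[5:].strip())
            if "error" in chunk:
                raise RuntimeError(chunk["error"]["message"])
            delta = chunk["choices"][0]["delta"]
            if delta.get("content"):
                print(delta["content"], end="", flush=True)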