rkihacker committed on
Commit 5e2bd86 · verified · 1 Parent(s): 3634096

Update main.py

Files changed (1)
main.py +65 -64
main.py CHANGED
@@ -22,7 +22,7 @@ if not SERVER_API_KEY:
     raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.8 (Raw Output Fix)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.3.0 (Streaming Fix)")
 
 # --- Authentication ---
 security = HTTPBearer()
@@ -52,7 +52,7 @@ class ModelList(BaseModel):
 
 class ChatMessage(BaseModel):
     role: Literal["system", "user", "assistant", "tool"]
-    content: Union[str, List[Dict[str, Any]]]
+    content: Union[str, List[Dict[str, Any]], None]  # Allow content to be None for tool calls
     name: Optional[str] = None
     tool_calls: Optional[List[Any]] = None
 
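Note on the `content` change: in the OpenAI chat schema, an assistant message that carries `tool_calls` has `content: null`, so the old `Union[str, List[Dict[str, Any]]]` annotation rejected valid payloads. A minimal standalone sketch of the case that now validates (the model is redeclared here rather than imported from main.py, and the `get_weather` tool is invented for illustration):

```python
# Minimal sketch: why ChatMessage.content must admit None.
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel

class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: Union[str, List[Dict[str, Any]], None]  # None for tool-call turns
    name: Optional[str] = None
    tool_calls: Optional[List[Any]] = None

# An OpenAI-style assistant turn that only carries a tool call:
msg = ChatMessage(
    role="assistant",
    content=None,  # would fail validation under the old Union[str, List[Dict[str, Any]]]
    tool_calls=[{"id": "call_1", "type": "function",
                 "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"}}],
)
print(msg.content)  # None
```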
@@ -127,33 +127,37 @@ class ChatCompletionChunk(BaseModel):
 # --- Supported Models ---
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
-    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
+    "claude-4.5-haiku": "anthropic/claude-3-haiku-20240307",
+    "claude-4.5-sonnet": "anthropic/claude-3.5-sonnet-20240620",  # Updated to correct model ID
     "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
 }
 
 # --- Core Logic ---
 
 def generate_request_id() -> str:
-    """Generates a unique request ID in the user-specified format."""
-    return f"gen-{int(time.time())}-{secrets.token_hex(8)}"
+    """Generates a unique request ID."""
+    return f"chatcmpl-{secrets.token_hex(16)}"
 
 def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
     prompt_parts = []
     system_prompt = None
     image_input = None
 
+    # Handle functions/tools if provided
+    tools_prompt_section = ""
     if functions:
-        functions_text = "You have access to the following tools. Use them if required to answer the user's question.\n\n"
+        tools_prompt_section += "You have access to the following tools. Use them if required to answer the user's question.\n\n"
         for func in functions:
-            functions_text += f"- Function: {func.name}\n"
-            if func.description: functions_text += f"  Description: {func.description}\n"
-            if func.parameters: functions_text += f"  Parameters: {json.dumps(func.parameters)}\n"
-        prompt_parts.append(functions_text)
+            tools_prompt_section += f"- Function: {func.name}\n"
+            if func.description: tools_prompt_section += f"  Description: {func.description}\n"
+            if func.parameters: tools_prompt_section += f"  Parameters: {json.dumps(func.parameters)}\n"
+        tools_prompt_section += "\nTo call a function, respond with a JSON object like this: {\"name\": \"function_name\", \"arguments\": {\"arg1\": \"value1\"}}\n"
 
     for msg in messages:
         if msg.role == "system":
             system_prompt = str(msg.content)
+            if tools_prompt_section:
+                system_prompt += "\n\n" + tools_prompt_section
         elif msg.role == "assistant":
             if msg.tool_calls:
                 tool_calls_text = "\nTool calls:\n"
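To see what the new `tools_prompt_section` actually injects into the system prompt, here is a hedged, standalone rendering for one example function; `SimpleNamespace` stands in for the app's `FunctionDefinition` model and `get_weather` is invented for illustration:

```python
# Illustrative: the tools prompt section rendered for one function.
import json
from types import SimpleNamespace

func = SimpleNamespace(
    name="get_weather",
    description="Look up current weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}},
)

tools_prompt_section = "You have access to the following tools. Use them if required to answer the user's question.\n\n"
tools_prompt_section += f"- Function: {func.name}\n"
if func.description: tools_prompt_section += f"  Description: {func.description}\n"
if func.parameters: tools_prompt_section += f"  Parameters: {json.dumps(func.parameters)}\n"
tools_prompt_section += "\nTo call a function, respond with a JSON object like this: {\"name\": \"function_name\", \"arguments\": {\"arg1\": \"value1\"}}\n"
print(tools_prompt_section)
```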
@@ -177,9 +181,12 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
             user_text_content = str(msg.content)
             prompt_parts.append(f"User: {user_text_content}")
 
-    prompt_parts.append("Assistant:")  # Let the model generate the space after this
+    if not system_prompt and tools_prompt_section:
+        system_prompt = tools_prompt_section
+
+    prompt_parts.append("Assistant:")
     return {
-        "prompt": "\n\n".join(prompt_parts),
+        "prompt": "\n".join(prompt_parts),
         "system_prompt": system_prompt,
         "image": image_input
     }
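For reference, the prompt this function now assembles for a short exchange looks like the sketch below (single-newline join, trailing `Assistant:` left open for the model to complete):

```python
# Sketch of the assembled prompt for a two-turn conversation.
prompt_parts = [
    "User: What is the capital of France?",
    "Assistant: Paris.",
    "User: And of Italy?",
]
prompt_parts.append("Assistant:")
print("\n".join(prompt_parts))
# User: What is the capital of France?
# Assistant: Paris.
# User: And of Italy?
# Assistant:
```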
@@ -198,7 +205,7 @@ def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
         pass
     return None
 
-async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str):
+async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str, model_name: str):
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
@@ -225,7 +232,6 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                 current_event = None
-                accumulated_content = ""
 
                 async for line in sse.aiter_lines():
                     if not line: continue
@@ -237,34 +243,42 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
                         if not raw_data: continue
 
                         try:
+                            # Replicate sends JSON-encoded strings. This correctly handles escaped chars like \n
                             content_token = json.loads(raw_data)
+                            if not isinstance(content_token, str):
+                                content_token = str(content_token)  # Ensure it's a string
                         except (json.JSONDecodeError, TypeError):
-                            content_token = raw_data
-
-                        # ### THIS IS THE FIX ###
-                        # There is NO lstrip() or strip() here.
-                        # This sends the raw, unmodified token from Replicate.
-                        # If the log shows "HowcanI", it's because the model
-                        # sent "How", "can", "I" as separate tokens.
-
-                        accumulated_content += content_token
+                            content_token = raw_data  # Fallback for non-JSON data
+
                         completion_tokens += 1
 
-                        function_call = parse_function_call(accumulated_content)
-                        if function_call:
-                            tool_call = ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))
-                            chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
-                            yield f"data: {chunk.json()}\n\n"
-                        else:
-                            if content_token:
-                                chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
-                                yield f"data: {chunk.json()}\n\n"
+                        if content_token:
+                            chunk = ChatCompletionChunk(
+                                id=request_id,
+                                created=int(time.time()),
+                                model=model_name,
+                                choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token))]
+                            )
+                            yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    # --- THIS IS THE CRITICAL FIX for incomplete responses ---
+                    elif line.startswith("data:") and current_event == "error":
+                        raw_data = line[5:].strip()
+                        error_details = raw_data
+                        try:
+                            error_json = json.loads(raw_data)
+                            error_details = error_json.get("detail") or str(error_json)
+                        except json.JSONDecodeError: pass
+
+                        error_chunk = {"error": {"message": f"Replicate stream error: {error_details}", "type": "replicate_error"}}
+                        yield f"data: {json.dumps(error_chunk)}\n\n"
+                        break  # Stop streaming on error
 
                     elif current_event == "done":
                         end_time = time.time()
                         usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, inference_time=round(end_time - start_time, 3))
-                        usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(), finish_reason="stop")], usage=usage)
-                        yield f"data: {usage_chunk.json()}\n\n"
+                        usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=model_name, choices=[ChoiceDelta(index=0, delta=DeltaMessage(), finish_reason="stop")], usage=usage)
+                        yield f"data: {usage_chunk.model_dump_json()}\n\n"
                         break
 
         except httpx.ReadTimeout:
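Since the chunks now follow OpenAI's `chat.completion.chunk` SSE framing (including the new in-band error chunks), a client can consume the endpoint roughly as sketched below; the base URL and key are placeholders and `stream_chat` is a hypothetical helper, not part of main.py:

```python
# Hypothetical client for the streaming endpoint; URL and key are placeholders.
import json
import httpx

def stream_chat(base_url: str, api_key: str, model: str, prompt: str):
    payload = {"model": model, "stream": True,
               "messages": [{"role": "user", "content": prompt}]}
    headers = {"Authorization": f"Bearer {api_key}"}
    with httpx.stream("POST", f"{base_url}/v1/chat/completions",
                      json=payload, headers=headers, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue
            chunk = json.loads(line[6:])
            if "error" in chunk:  # in-band error chunks from the fixed handler
                raise RuntimeError(chunk["error"]["message"])
            delta = chunk["choices"][0]["delta"]
            if delta.get("content"):
                print(delta["content"], end="", flush=True)

# stream_chat("http://localhost:8000", "sk-...", "llama3-8b-instruct", "Hello!")
```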
@@ -276,16 +290,10 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
 # --- Endpoints ---
 @app.get("/v1/models", dependencies=[Depends(verify_api_key)])
 async def list_models():
-    """
-    Protected endpoint to list available models.
-    """
     return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
 async def create_chat_completion(request: ChatCompletionRequest):
-    """
-    Protected endpoint to create a chat completion.
-    """
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
@@ -294,13 +302,17 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     replicate_input = {
         "prompt": formatted["prompt"],
-        "temperature": request.temperature or 0.7,
-        "top_p": request.top_p or 1.0
+        "temperature": request.temperature if request.temperature is not None else 0.7,
+        "top_p": request.top_p if request.top_p is not None else 1.0,
     }
 
     if request.max_tokens is not None:
         replicate_input["max_new_tokens"] = request.max_tokens
 
+    # --- THIS IS THE SECOND FIX for incomplete responses ---
+    if request.stop:
+        replicate_input["stop_sequences"] = request.stop if isinstance(request.stop, list) else [request.stop]
+
     if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
     if formatted["image"]: replicate_input["image"] = formatted["image"]
 
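The `stop_sequences` line normalizes OpenAI's `stop` parameter, which may be a single string or a list, into the list form the Replicate input expects. A quick sketch of the three cases:

```python
# Sketch of the stop-parameter normalization used above.
def normalize_stop(stop):
    if not stop:
        return None
    return stop if isinstance(stop, list) else [stop]

print(normalize_stop("###"))            # ['###']
print(normalize_stop(["###", "\n\n"]))  # ['###', '\n\n']
print(normalize_stop(None))             # None
```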
@@ -308,7 +320,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         return StreamingResponse(
-            stream_replicate_response(replicate_model_id, replicate_input, request_id),
+            stream_replicate_response(replicate_model_id, replicate_input, request_id, request.model),
             media_type="text/event-stream"
         )
 
@@ -323,21 +335,13 @@ async def create_chat_completion(request: ChatCompletionRequest):
     resp.raise_for_status()
     pred = resp.json()
 
-    # Handle the 'output' field which could be a list, string, or null
+    # Handle errors in non-streaming mode
+    if pred.get("status") == "failed":
+        raise HTTPException(status_code=500, detail=f"Replicate prediction failed: {pred.get('error')}")
+
     raw_output = pred.get("output")
-
-    if isinstance(raw_output, list):
-        output = "".join(raw_output)  # Expected case: list of strings
-    elif isinstance(raw_output, str):
-        output = raw_output  # Handle if it's just a single string
-    else:
-        output = ""
-
-    # ### THIS IS THE FIX ###
-    # Removed output.strip() to return the raw response.
-    # This fixes the bug where a single space (" ") response
-    # would become "" and show content: "" in the JSON.
-
+    output = "".join(raw_output) if isinstance(raw_output, list) else (raw_output or "")
+
     end_time = time.time()
     prompt_tokens = len(replicate_input.get("prompt", "")) // 4
     completion_tokens = len(output) // 4
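The new one-liner replaces the old if/elif/else block and covers the three shapes Replicate's `output` field can take (list of string tokens, plain string, or null), without stripping whitespace:

```python
# The three shapes the non-streaming 'output' field can take.
def normalize_output(raw_output):
    return "".join(raw_output) if isinstance(raw_output, list) else (raw_output or "")

print(repr(normalize_output(["Hel", "lo", "!"])))  # 'Hello!'
print(repr(normalize_output("Hello!")))            # 'Hello!'
print(repr(normalize_output(None)))                # ''
```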
@@ -348,9 +352,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     function_call = parse_function_call(output)
     if function_call:
-        tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))]
+        tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=json.dumps(function_call["arguments"])))]
         finish_reason = "tool_calls"
-        message_content = None  # OpenAI standard: content is null when tool_calls are present
+        message_content = None
 
     return ChatCompletion(
         id=request_id,
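The added `json.dumps(...)` matters for compatibility: in OpenAI's schema, `function.arguments` is a JSON-encoded string, not a nested object. A small sketch of the difference (the parsed dict below is a hypothetical result of `parse_function_call`):

```python
import json

parsed = {"name": "get_weather", "arguments": {"city": "Paris"}}  # hypothetical parse result

# Before: arguments passed through as a dict (non-compliant).
before = {"name": parsed["name"], "arguments": parsed["arguments"]}
# After: arguments serialized to a JSON string, as OpenAI clients expect.
after = {"name": parsed["name"], "arguments": json.dumps(parsed["arguments"])}

print(type(before["arguments"]).__name__)  # dict
print(type(after["arguments"]).__name__)   # str
print(after["arguments"])                  # {"city": "Paris"}
```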
@@ -375,10 +379,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
 @app.get("/")
 async def root():
-    """
-    Root endpoint for health checks. Does not require authentication.
-    """
-    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.8"}
+    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.3.0"}
 
 @app.middleware("http")
 async def add_performance_headers(request, call_next):
@@ -386,5 +387,5 @@ async def add_performance_headers(request, call_next):
     response = await call_next(request)
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = str(round(process_time, 3))
-    response.headers["X-API-Version"] = "9.2.8"
+    response.headers["X-API-Version"] = "9.3.0"
     return response
 