rkihacker committed
Commit 67e24f8 · verified · 1 Parent(s): 5e2bd86

Update main.py

Files changed (1)
  1. main.py +66 -66
main.py CHANGED
@@ -22,7 +22,7 @@ if not SERVER_API_KEY:
     raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.3.0 (Streaming Fix)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.8 (Raw Output Fix)")
 
 # --- Authentication ---
 security = HTTPBearer()
@@ -52,7 +52,7 @@ class ModelList(BaseModel):
 
 class ChatMessage(BaseModel):
     role: Literal["system", "user", "assistant", "tool"]
-    content: Union[str, List[Dict[str, Any]], None] # Allow content to be None for tool calls
+    content: Union[str, List[Dict[str, Any]]]
     name: Optional[str] = None
     tool_calls: Optional[List[Any]] = None
 
@@ -127,37 +127,33 @@ class ChatCompletionChunk(BaseModel):
 # --- Supported Models ---
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-    "claude-4.5-haiku": "anthropic/claude-3-haiku-20240307",
-    "claude-4.5-sonnet": "anthropic/claude-3.5-sonnet-20240620", # Updated to correct model ID
+    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
+    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
     "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
 }
 
 # --- Core Logic ---
 
 def generate_request_id() -> str:
-    """Generates a unique request ID."""
-    return f"chatcmpl-{secrets.token_hex(16)}"
+    """Generates a unique request ID in the user-specified format."""
+    return f"gen-{int(time.time())}-{secrets.token_hex(8)}"
 
 def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
     prompt_parts = []
     system_prompt = None
     image_input = None
 
-    # Handle functions/tools if provided
-    tools_prompt_section = ""
     if functions:
-        tools_prompt_section += "You have access to the following tools. Use them if required to answer the user's question.\n\n"
+        functions_text = "You have access to the following tools. Use them if required to answer the user's question.\n\n"
         for func in functions:
-            tools_prompt_section += f"- Function: {func.name}\n"
-            if func.description: tools_prompt_section += f" Description: {func.description}\n"
-            if func.parameters: tools_prompt_section += f" Parameters: {json.dumps(func.parameters)}\n"
-        tools_prompt_section += "\nTo call a function, respond with a JSON object like this: {\"name\": \"function_name\", \"arguments\": {\"arg1\": \"value1\"}}\n"
+            functions_text += f"- Function: {func.name}\n"
+            if func.description: functions_text += f" Description: {func.description}\n"
+            if func.parameters: functions_text += f" Parameters: {json.dumps(func.parameters)}\n"
+        prompt_parts.append(functions_text)
 
     for msg in messages:
         if msg.role == "system":
             system_prompt = str(msg.content)
-            if tools_prompt_section:
-                system_prompt += "\n\n" + tools_prompt_section
         elif msg.role == "assistant":
             if msg.tool_calls:
                 tool_calls_text = "\nTool calls:\n"
@@ -181,12 +177,9 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
             user_text_content = str(msg.content)
             prompt_parts.append(f"User: {user_text_content}")
 
-    if not system_prompt and tools_prompt_section:
-        system_prompt = tools_prompt_section
-
-    prompt_parts.append("Assistant:")
+    prompt_parts.append("Assistant:") # Let the model generate the space after this
     return {
-        "prompt": "\n".join(prompt_parts),
+        "prompt": "\n\n".join(prompt_parts),
         "system_prompt": system_prompt,
         "image": image_input
     }
@@ -205,7 +198,7 @@ def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
         pass
     return None
 
-async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str, model_name: str):
+async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str):
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
@@ -232,6 +225,7 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                 current_event = None
+                accumulated_content = ""
 
                 async for line in sse.aiter_lines():
                     if not line: continue
@@ -243,42 +237,34 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
                         if not raw_data: continue
 
                         try:
-                            # Replicate sends JSON-encoded strings. This correctly handles escaped chars like \n
                             content_token = json.loads(raw_data)
-                            if not isinstance(content_token, str):
-                                content_token = str(content_token) # Ensure it's a string
                         except (json.JSONDecodeError, TypeError):
-                            content_token = raw_data # Fallback for non-JSON data
-
-                        completion_tokens += 1
+                            content_token = raw_data
 
-                        if content_token:
-                            chunk = ChatCompletionChunk(
-                                id=request_id,
-                                created=int(time.time()),
-                                model=model_name,
-                                choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token))]
-                            )
-                            yield f"data: {chunk.model_dump_json()}\n\n"
-
-                        # --- THIS IS THE CRITICAL FIX for incomplete responses ---
-                        elif line.startswith("data:") and current_event == "error":
-                            raw_data = line[5:].strip()
-                            error_details = raw_data
-                            try:
-                                error_json = json.loads(raw_data)
-                                error_details = error_json.get("detail") or str(error_json)
-                            except json.JSONDecodeError: pass
+                        # ### THIS IS THE FIX ###
+                        # There is NO lstrip() or strip() here.
+                        # This sends the raw, unmodified token from Replicate.
+                        # If the log shows "HowcanI", it's because the model
+                        # sent "How", "can", "I" as separate tokens.
 
-                            error_chunk = {"error": {"message": f"Replicate stream error: {error_details}", "type": "replicate_error"}}
-                            yield f"data: {json.dumps(error_chunk)}\n\n"
-                            break # Stop streaming on error
+                        accumulated_content += content_token
+                        completion_tokens += 1
 
+                        function_call = parse_function_call(accumulated_content)
+                        if function_call:
+                            tool_call = ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))
+                            chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
+                            yield f"data: {chunk.json()}\n\n"
+                        else:
+                            if content_token:
+                                chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
+                                yield f"data: {chunk.json()}\n\n"
+
                     elif current_event == "done":
                         end_time = time.time()
                         usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, inference_time=round(end_time - start_time, 3))
-                        usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=model_name, choices=[ChoiceDelta(index=0, delta=DeltaMessage(), finish_reason="stop")], usage=usage)
-                        yield f"data: {usage_chunk.model_dump_json()}\n\n"
+                        usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(), finish_reason="stop")], usage=usage)
+                        yield f"data: {usage_chunk.json()}\n\n"
                         break
 
         except httpx.ReadTimeout:
@@ -290,10 +276,16 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
 # --- Endpoints ---
 @app.get("/v1/models", dependencies=[Depends(verify_api_key)])
 async def list_models():
+    """
+    Protected endpoint to list available models.
+    """
     return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
 async def create_chat_completion(request: ChatCompletionRequest):
+    """
+    Protected endpoint to create a chat completion.
+    """
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
@@ -302,17 +294,13 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     replicate_input = {
         "prompt": formatted["prompt"],
-        "temperature": request.temperature if request.temperature is not None else 0.7,
-        "top_p": request.top_p if request.top_p is not None else 1.0,
+        "temperature": request.temperature or 0.7,
+        "top_p": request.top_p or 1.0
     }
 
     if request.max_tokens is not None:
         replicate_input["max_new_tokens"] = request.max_tokens
 
-    # --- THIS IS THE SECOND FIX for incomplete responses ---
-    if request.stop:
-        replicate_input["stop_sequences"] = request.stop if isinstance(request.stop, list) else [request.stop]
-
     if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"]
     if formatted["image"]: replicate_input["image"] = formatted["image"]
 
@@ -320,7 +308,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         return StreamingResponse(
-            stream_replicate_response(replicate_model_id, replicate_input, request_id, request.model),
+            stream_replicate_response(replicate_model_id, replicate_input, request_id),
             media_type="text/event-stream"
         )
 
@@ -335,13 +323,21 @@ async def create_chat_completion(request: ChatCompletionRequest):
         resp.raise_for_status()
         pred = resp.json()
 
-        # Handle errors in non-streaming mode
-        if pred.get("status") == "failed":
-            raise HTTPException(status_code=500, detail=f"Replicate prediction failed: {pred.get('error')}")
-
+        # Handle the 'output' field which could be a list, string, or null
        raw_output = pred.get("output")
-        output = "".join(raw_output) if isinstance(raw_output, list) else (raw_output or "")
 
+        if isinstance(raw_output, list):
+            output = "".join(raw_output) # Expected case: list of strings
+        elif isinstance(raw_output, str):
+            output = raw_output # Handle if it's just a single string
+        else:
+            output = ""
+
+        # ### THIS IS THE FIX ###
+        # Removed output.strip() to return the raw response.
+        # This fixes the bug where a single space (" ") response
+        # would become "" and show content: "" in the JSON.
+
         end_time = time.time()
         prompt_tokens = len(replicate_input.get("prompt", "")) // 4
         completion_tokens = len(output) // 4
@@ -352,9 +348,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     function_call = parse_function_call(output)
     if function_call:
-        tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=json.dumps(function_call["arguments"])))]
+        tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))]
         finish_reason = "tool_calls"
-        message_content = None
+        message_content = None # OpenAI standard: content is null when tool_calls are present
 
     return ChatCompletion(
         id=request_id,
@@ -379,7 +375,10 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
 @app.get("/")
 async def root():
-    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.3.0"}
+    """
+    Root endpoint for health checks. Does not require authentication.
+    """
+    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.8"}
 
 @app.middleware("http")
 async def add_performance_headers(request, call_next):
@@ -387,5 +386,6 @@ async def add_performance_headers(request, call_next):
     response = await call_next(request)
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = str(round(process_time, 3))
-    response.headers["X-API-Version"] = "9.3.0"
-    return response
+    response.headers["X-API-Version"] = "9.2.8"
+    return response
+
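For quick verification, here is a minimal client sketch against the endpoints this file exposes. It is an assumption-laden example, not part of the commit: the base URL and port are placeholders for wherever the server is deployed, the bearer token must match SERVER_API_KEY, and the model key comes from SUPPORTED_MODELS in the diff.

# Hypothetical smoke test for this commit; BASE_URL, port, and the API key
# below are assumptions, not values from the source.
import httpx

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer YOUR_SERVER_API_KEY"}  # must match SERVER_API_KEY

# List the models the proxy advertises (the keys of SUPPORTED_MODELS).
models = httpx.get(f"{BASE_URL}/v1/models", headers=HEADERS).json()
print([m["id"] for m in models["data"]])

# Non-streaming chat completion; per this commit, the output is returned
# raw and unstripped, so leading/trailing whitespace is preserved.
payload = {
    "model": "llama3-8b-instruct",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
}
resp = httpx.post(f"{BASE_URL}/v1/chat/completions", headers=HEADERS, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

A streaming check would instead set "stream": True and read the text/event-stream body line by line until the chunk carrying finish_reason "stop" and the usage block arrives.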