rkihacker committed
Commit a9bb1ec · verified · 1 parent: 2767573

Update main.py

Files changed (1):
main.py +235 -148
main.py CHANGED
@@ -3,12 +3,12 @@ import os
 import httpx
 import json
 import time
+import asyncio
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi.responses import StreamingResponse, JSONResponse
 from pydantic import BaseModel, Field
 from typing import List, Dict, Any, Optional, Union, Literal
 from dotenv import load_dotenv
-from sse_starlette.sse import EventSourceResponse
 
 # Load environment variables
 load_dotenv()
@@ -17,7 +17,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.1.0 (Enhanced Token Tracking)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.0 (Full OpenAI Compatibility)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -25,9 +25,34 @@ class ModelCard(BaseModel):
 class ModelList(BaseModel):
     object: str = "list"; data: List[ModelCard] = []
 class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]
-class OpenAIChatCompletionRequest(BaseModel):
-    model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
+    role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]; name: Optional[str] = None
+class FunctionDefinition(BaseModel):
+    name: str; description: Optional[str] = None; parameters: Optional[Dict[str, Any]] = None
+class ToolDefinition(BaseModel):
+    type: Literal["function"]; function: FunctionDefinition
+class FunctionCall(BaseModel):
+    name: str; arguments: str
+class ToolCall(BaseModel):
+    id: str; type: Literal["function"] = "function"; function: FunctionCall
+class ChatCompletionRequest(BaseModel):
+    model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0
+    max_tokens: Optional[int] = None; stream: Optional[bool] = False; stop: Optional[Union[str, List[str]]] = None
+    tools: Optional[List[ToolDefinition]] = None; tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    functions: Optional[List[FunctionDefinition]] = None; function_call: Optional[Union[str, Dict[str, str]]] = None
+
+class Choice(BaseModel):
+    index: int; message: ChatMessage; finish_reason: Optional[str] = None
+class Usage(BaseModel):
+    prompt_tokens: int; completion_tokens: int; total_tokens: int; inference_time: Optional[float] = None
+class ChatCompletion(BaseModel):
+    id: str; object: str = "chat.completion"; created: int; model: str; choices: List[Choice]; usage: Usage
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None; content: Optional[str] = None; tool_calls: Optional[List[ToolCall]] = None
+class ChoiceDelta(BaseModel):
+    index: int; delta: DeltaMessage; finish_reason: Optional[str] = None
+class ChatCompletionChunk(BaseModel):
+    id: str; object: str = "chat.completion.chunk"; created: int; model: str; choices: List[ChoiceDelta]; usage: Optional[Usage] = None
 
 # --- Supported Models ---
 SUPPORTED_MODELS = {
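Reviewer note: the new request schema keeps the legacy functions/function_call fields alongside tools, so older clients keep working. A minimal sketch of a request that the new ChatCompletionRequest model accepts (assumes these classes are importable from main.py; the model id is illustrative and must be a key of SUPPORTED_MODELS, whose entries are collapsed in this diff):

```python
# Sketch: validating a legacy-style function-calling request against the new schema.
from main import ChatCompletionRequest, ChatMessage, FunctionDefinition  # assumes main.py is importable

req = ChatCompletionRequest(
    model="meta/meta-llama-3-70b-instruct",  # illustrative; must be a SUPPORTED_MODELS key
    messages=[ChatMessage(role="user", content="What's the weather in Paris?")],
    functions=[FunctionDefinition(
        name="get_weather",
        description="Look up current weather for a city",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    )],
)
print(req.functions[0].name)  # -> get_weather
```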
@@ -38,21 +63,36 @@ SUPPORTED_MODELS = {
 }
 
 # --- Core Logic ---
-def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
-    """
-    Formats the input for Replicate's API, flattening the message history into a
-    single 'prompt' string and handling images separately.
-    """
-    payload = {}
+def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
+    """Convert OpenAI messages to Replicate-compatible format with function calling support."""
     prompt_parts = []
     system_prompt = None
     image_input = None
 
-    for msg in request.messages:
+    # Add functions to system prompt if provided
+    if functions:
+        functions_text = "\n\nAvailable functions:\n"
+        for func in functions:
+            functions_text += f"- {func.name}: {func.description or 'No description'}\n"
+            if func.parameters:
+                functions_text += f"  Parameters: {json.dumps(func.parameters)}\n"
+        prompt_parts.append(functions_text)
+
+    for msg in messages:
         if msg.role == "system":
             system_prompt = str(msg.content)
         elif msg.role == "assistant":
-            prompt_parts.append(f"Assistant: {msg.content}")
+            # Handle tool calls in assistant messages
+            if hasattr(msg, 'tool_calls') and msg.tool_calls:
+                tool_calls_text = "\nTool calls:\n"
+                for tool_call in msg.tool_calls:
+                    tool_calls_text += f"- {tool_call.function.name}({tool_call.function.arguments})\n"
+                prompt_parts.append(f"Assistant: {tool_calls_text}")
+            else:
+                prompt_parts.append(f"Assistant: {msg.content}")
+        elif msg.role == "tool":
+            # Handle tool responses
+            prompt_parts.append(f"Tool Response: {msg.content}")
         elif msg.role == "user":
            user_text_content = ""
            if isinstance(msg.content, list):
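To make the flattening concrete: assuming the context lines collapsed between this hunk and the next copy plain string content into user_text_content, a short conversation comes out of format_messages_for_replicate roughly like this (a sketch, not part of the commit; the dict return shape appears in the next hunk):

```python
# Sketch: what the prompt flattening produces for a short conversation.
from main import format_messages_for_replicate, ChatMessage  # assumes main.py is importable

formatted = format_messages_for_replicate([
    ChatMessage(role="system", content="You are terse."),
    ChatMessage(role="user", content="Hi"),
    ChatMessage(role="assistant", content="Hello."),
    ChatMessage(role="user", content="Name a color."),
])
print(formatted["system_prompt"])  # routed separately: "You are terse."
print(formatted["prompt"])
# User: Hi
#
# Assistant: Hello.
#
# User: Name a color.
#
# Assistant:
```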
@@ -67,37 +107,46 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
             prompt_parts.append(f"User: {user_text_content}")
 
     prompt_parts.append("Assistant:")
-    payload["prompt"] = "\n\n".join(prompt_parts)
-
-    if system_prompt:
-        payload["system_prompt"] = system_prompt
-    if image_input:
-        payload["image"] = image_input
-
-    if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
-    if request.temperature: payload["temperature"] = request.temperature
-    if request.top_p: payload["top_p"] = request.top_p
-
-    return payload
+    return {
+        "prompt": "\n\n".join(prompt_parts),
+        "system_prompt": system_prompt,
+        "image": image_input
+    }
+
+def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
+    """Parse function call from model response."""
+    try:
+        # Look for JSON-like function call patterns
+        if "function_call" in content or ("name" in content and "arguments" in content):
+            # Extract JSON part
+            start = content.find("{")
+            end = content.rfind("}") + 1
+            if start != -1 and end > start:
+                json_str = content[start:end]
+                parsed = json.loads(json_str)
+                if "name" in parsed and "arguments" in parsed:
+                    return parsed
+    except (json.JSONDecodeError, Exception):
+        pass
+    return None
 
-async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
-    """Handles the full streaming lifecycle with enhanced token tracking and timing."""
+async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str):
+    """Stream response with full OpenAI compatibility including tool calls."""
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
     start_time = time.time()
-    prompt_tokens = len(input_payload.get("prompt", "")) // 4  # Rough estimation
+    prompt_tokens = len(input_payload.get("prompt", "")) // 4
     completion_tokens = 0
 
-    async with httpx.AsyncClient(timeout=60.0) as client:
+    async with httpx.AsyncClient(timeout=300.0) as client:
         try:
             response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
             response.raise_for_status()
             prediction = response.json()
             stream_url = prediction.get("urls", {}).get("stream")
-            prediction_id = prediction.get("id", "stream-unknown")
             if not stream_url:
-                yield json.dumps({'error': {'message': 'Model did not return a stream URL.'}})
+                yield f"data: {json.dumps({'error': {'message': 'Model did not return a stream URL.'}})}\n\n"
                 return
         except httpx.HTTPStatusError as e:
             error_details = e.response.text
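The brace-scanning heuristic in parse_function_call can be sanity-checked in isolation: it only extracts something when both "name" and "arguments" occur in the text and the outermost braces delimit valid JSON carrying both keys. A sketch (assumes parse_function_call is importable from main.py):

```python
# Sketch: what the heuristic does and does not extract.
from main import parse_function_call  # assumes main.py is importable

hit = parse_function_call('Sure, calling it: {"name": "get_weather", "arguments": "{\\"city\\": \\"Paris\\"}"}')
print(hit)   # -> {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}

miss = parse_function_call("The function is named foo and takes two arguments.")
print(miss)  # -> None: "name"/"arguments" appear, but there is no JSON object to slice out
```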
@@ -105,116 +154,99 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                 error_json = e.response.json()
                 error_details = error_json.get("detail", error_details)
             except json.JSONDecodeError: pass
-            yield json.dumps({'error': {'message': f'Upstream Error: {error_details}', 'type': 'replicate_error'}})
+            yield f"data: {json.dumps({'error': {'message': f'Upstream Error: {error_details}', 'type': 'replicate_error'}})}\n\n"
             return
 
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                 current_event = None
+                accumulated_content = ""
+
                 async for line in sse.aiter_lines():
-                    if not line:  # Skip empty lines
-                        continue
+                    if not line: continue
+
                     if line.startswith("event:"):
                         current_event = line[len("event:"):].strip()
-                    elif line.startswith("data:"):
-                        # Remove "data:" prefix and optional space
-                        raw_data = line[5:]  # Remove "data:"
-                        if raw_data.startswith(" "):
-                            data_content = raw_data[1:]  # Remove the first space only
-                        else:
-                            data_content = raw_data
+                    elif line.startswith("data:") and current_event == "output":
+                        raw_data = line[5:].strip()
+                        if not raw_data: continue
 
-                        if current_event == "output":
-                            if not data_content:
-                                continue
-
-                            content_token = ""
-                            try:
-                                # Handle JSON-encoded strings properly
-                                content_token = json.loads(data_content)
-                            except (json.JSONDecodeError, TypeError):
-                                # Handle plain text tokens
-                                content_token = data_content
-
-                            completion_tokens += 1
-                            chunk = {
-                                "choices": [{
-                                    "delta": {"content": content_token},
-                                    "finish_reason": None,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "native_finish_reason": None
-                                }],
-                                "created": int(time.time()),
-                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
-                                "model": replicate_model_id,
-                                "object": "chat.completion.chunk",
-                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
-                            }
-                            yield json.dumps(chunk)
-
-                        elif current_event == "done":
-                            # Calculate timing
-                            end_time = time.time()
-                            inference_time = end_time - start_time
-
-                            # Send usage chunk before done
-                            usage_chunk = {
-                                "choices": [{
-                                    "delta": {},
-                                    "finish_reason": None,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "native_finish_reason": None
-                                }],
-                                "created": int(time.time()),
-                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
-                                "model": replicate_model_id,
-                                "object": "chat.completion.chunk",
-                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate",
-                                "usage": {
-                                    "cache_discount": 0,
-                                    "completion_tokens": completion_tokens,
-                                    "completion_tokens_details": {"image_tokens": 0, "reasoning_tokens": 0},
-                                    "cost": 0,
-                                    "cost_details": {
-                                        "upstream_inference_completions_cost": 0,
-                                        "upstream_inference_cost": None,
-                                        "upstream_inference_prompt_cost": 0
-                                    },
-                                    "input_tokens": prompt_tokens,
-                                    "is_byok": False,
-                                    "prompt_tokens": prompt_tokens,
-                                    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
-                                    "total_tokens": prompt_tokens + completion_tokens,
-                                    "inference_time": round(inference_time, 3)
-                                }
-                            }
-                            yield json.dumps(usage_chunk)
+                        content_token = ""
+                        try:
+                            content_token = json.loads(raw_data)
+                        except (json.JSONDecodeError, TypeError):
+                            content_token = raw_data
+
+                        accumulated_content += content_token
+                        completion_tokens += 1
+
+                        # Check for function calls in accumulated content
+                        function_call = parse_function_call(accumulated_content)
+                        if function_call:
+                            # Send tool call chunk
+                            tool_call = ToolCall(
+                                id=f"call_{int(time.time())}",
+                                function=FunctionCall(
+                                    name=function_call["name"],
+                                    arguments=function_call["arguments"]
+                                )
+                            )
+                            chunk = ChatCompletionChunk(
+                                id=request_id,
+                                created=int(time.time()),
+                                model=replicate_model_id,
+                                choices=[ChoiceDelta(
+                                    index=0,
+                                    delta=DeltaMessage(tool_calls=[tool_call]),
+                                    finish_reason=None
+                                )]
+                            )
+                            yield f"data: {chunk.json()}\n\n"
+                        else:
+                            # Send regular content chunk
+                            chunk = ChatCompletionChunk(
+                                id=request_id,
+                                created=int(time.time()),
+                                model=replicate_model_id,
+                                choices=[ChoiceDelta(
+                                    index=0,
+                                    delta=DeltaMessage(content=content_token),
+                                    finish_reason=None
+                                )]
+                            )
+                            yield f"data: {chunk.json()}\n\n"
 
-                            # Send final chunk with stop reason
-                            final_chunk = {
-                                "choices": [{
-                                    "delta": {},
-                                    "finish_reason": "stop",
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "native_finish_reason": "end_turn"
-                                }],
-                                "created": int(time.time()),
-                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
-                                "model": replicate_model_id,
-                                "object": "chat.completion.chunk",
-                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
-                            }
-                            yield json.dumps(final_chunk)
-                            break
+                    elif current_event == "done":
+                        # Send final usage chunk
+                        end_time = time.time()
+                        inference_time = end_time - start_time
+
+                        usage = Usage(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                            inference_time=round(inference_time, 3)
+                        )
+
+                        usage_chunk = ChatCompletionChunk(
+                            id=request_id,
+                            created=int(time.time()),
+                            model=replicate_model_id,
+                            choices=[ChoiceDelta(
+                                index=0,
+                                delta=DeltaMessage(),
+                                finish_reason="stop"
+                            )],
+                            usage=usage
+                        )
+                        yield f"data: {usage_chunk.json()}\n\n"
+                        break
+
         except httpx.ReadTimeout:
-            yield json.dumps({'error': {'message': 'Stream timed out.', 'type': 'timeout_error'}})
+            yield f"data: {json.dumps({'error': {'message': 'Stream timed out.', 'type': 'timeout_error'}})}\n\n"
             return
 
-        # Send [DONE] event
-        yield "[DONE]"
+        yield "data: [DONE]\n\n"
 
 # --- Endpoints ---
 @app.get("/v1/models")
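Reviewer note: with StreamingResponse emitting pre-framed `data: ...\n\n` lines (instead of EventSourceResponse wrapping raw payloads), the wire format now matches what OpenAI SDKs parse. An abridged, illustrative stream follows; ids, timestamps, and the model id are made up, and pydantic's .json() will additionally serialize unset optional fields such as role and usage as null:

```
data: {"id": "chatcmpl-1700000000", "object": "chat.completion.chunk", "created": 1700000000, "model": "meta/meta-llama-3-70b-instruct", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}

data: {"id": "chatcmpl-1700000000", "object": "chat.completion.chunk", "created": 1700000002, "model": "meta/meta-llama-3-70b-instruct", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 4, "completion_tokens": 12, "total_tokens": 16, "inference_time": 1.9}}

data: [DONE]
```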
 
@@ -222,23 +254,39 @@ async def list_models():
     return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions")
-async def create_chat_completion(request: OpenAIChatCompletionRequest):
+async def create_chat_completion(request: ChatCompletionRequest):
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
-    replicate_input = prepare_replicate_input(request)
-
+    # Format messages for Replicate
+    formatted = format_messages_for_replicate(request.messages, request.functions)
+    replicate_input = {
+        "prompt": formatted["prompt"],
+        "max_new_tokens": request.max_tokens or 512,
+        "temperature": request.temperature or 0.7,
+        "top_p": request.top_p or 1.0
+    }
+    if formatted["system_prompt"]:
+        replicate_input["system_prompt"] = formatted["system_prompt"]
+    if formatted["image"]:
+        replicate_input["image"] = formatted["image"]
+
+    request_id = f"chatcmpl-{int(time.time())}"
+
     if request.stream:
-        return EventSourceResponse(stream_replicate_sse(SUPPORTED_MODELS[request.model], replicate_input), media_type="text/event-stream")
+        return StreamingResponse(
+            stream_replicate_response(SUPPORTED_MODELS[request.model], replicate_input, request_id),
+            media_type="text/event-stream"
+        )
 
-    # Non-streaming fallback with usage data
+    # Non-streaming response
     url = f"https://api.replicate.com/v1/models/{SUPPORTED_MODELS[request.model]}/predictions"
-    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
+    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
     start_time = time.time()
 
     async with httpx.AsyncClient() as client:
         try:
-            resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=130.0)
+            resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=300.0)
             resp.raise_for_status()
             pred = resp.json()
             output = "".join(pred.get("output", []))
@@ -246,18 +294,57 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
             # Calculate timing and tokens
             end_time = time.time()
             inference_time = end_time - start_time
-            prompt_tokens = len(input_payload.get("prompt", "")) // 4  # Rough estimation
-            completion_tokens = len(output) // 4  # Rough estimation
+            prompt_tokens = len(replicate_input.get("prompt", "")) // 4
+            completion_tokens = len(output) // 4
+
+            # Parse function call if present
+            tool_calls = None
+            function_call = parse_function_call(output)
+            if function_call:
+                tool_call = ToolCall(
+                    id=f"call_{int(time.time())}",
+                    function=FunctionCall(
+                        name=function_call["name"],
+                        arguments=function_call["arguments"]
+                    )
+                )
+                tool_calls = [tool_call]
 
-            return {
-                "id": pred.get("id"), "object": "chat.completion", "created": int(time.time()), "model": request.model,
-                "choices": [{"index": 0, "message": {"role": "assistant", "content": output}, "finish_reason": "stop"}],
-                "usage": {
-                    "prompt_tokens": prompt_tokens,
-                    "completion_tokens": completion_tokens,
-                    "total_tokens": prompt_tokens + completion_tokens,
-                    "inference_time": round(inference_time, 3)
-                }
-            }
+            return ChatCompletion(
+                id=request_id,
+                created=int(time.time()),
+                model=request.model,
+                choices=[Choice(
+                    index=0,
+                    message=ChatMessage(
+                        role="assistant",
+                        content=output if not function_call else None,
+                        tool_calls=tool_calls
+                    ),
+                    finish_reason="tool_calls" if function_call else "stop"
+                )],
+                usage=Usage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                    inference_time=round(inference_time, 3)
+                )
+            )
         except httpx.HTTPStatusError as e:
             raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
+
+@app.get("/")
+async def root():
+    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.0"}
+
+# Performance optimization middleware
+@app.middleware("http")
+async def add_performance_headers(request, call_next):
+    start_time = time.time()
+    response = await call_next(request)
+    process_time = time.time() - start_time
+    response.headers["X-Process-Time"] = str(round(process_time, 3))
+    response.headers["X-API-Version"] = "9.2.0"
+    return response
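End to end, the layer should now be drop-in for OpenAI clients. A minimal sketch, assuming the app is served locally (e.g. `uvicorn main:app --port 8000`) and the model id is a SUPPORTED_MODELS key; the server reads REPLICATE_API_TOKEN itself, so the client-side key is a placeholder. One caveat worth a second look in review: ChatMessage declares content as required and has no tool_calls field, so the non-streaming tool-call branch constructing ChatMessage(content=None, tool_calls=...) depends on validation behavior that may reject it.

```python
# Sketch: calling the compatibility layer with the official openai client (v1 SDK).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")  # auth is handled server-side

stream = client.chat.completions.create(
    model="meta/meta-llama-3-70b-instruct",  # illustrative; must be a SUPPORTED_MODELS key
    messages=[{"role": "user", "content": "Say hi in five words."}],
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```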