rkihacker committed
Commit 2fc646f · verified · 1 parent: ea53c08

Update main.py

Files changed (1): main.py (+51 -70)
main.py CHANGED
@@ -2,7 +2,6 @@ import os
 import httpx
 import json
 import time
-import asyncio
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
@@ -21,101 +20,86 @@ if not REPLICATE_API_TOKEN:
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="2.0.0 (Native Streaming & Context Fixed)",
+    version="2.1.0 (Model Input Fixed)",
 )
 
-# --- Pydantic Models for OpenAI Compatibility ---
-
+# --- Pydantic Models ---
 class ModelCard(BaseModel):
-    id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "replicate"
+    id: str; object: str = "model"; created: int = Field(default_factory=lambda: int(time.time())); owned_by: str = "replicate"
 
 class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = []
+    object: str = "list"; data: List[ModelCard] = []
 
 class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant", "tool"]
-    content: Union[str, List[Dict[str, Any]]]
-
-class ToolFunction(BaseModel):
-    name: str
-    description: str
-    parameters: Dict[str, Any]
-
-class Tool(BaseModel):
-    type: Literal["function"]
-    function: ToolFunction
+    role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]
 
 class OpenAIChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[ChatMessage]
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 1.0
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-    tools: Optional[List[Tool]] = None
-    tool_choice: Optional[Union[str, Dict]] = None
-
-# --- Replicate Model Mapping ---
+    model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
+
+# --- Model Mapping ---
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
     "claude-4.5-haiku": "anthropic/claude-4.5-haiku"
 }
 
-
 # --- Helper Functions ---
 
 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     """
-    Prepares the input payload for Replicate's chat models.
-    This now correctly passes the messages array for context.
+    Prepares the input payload for Replicate, handling model-specific formats.
     """
-    # Convert Pydantic message objects to a list of dictionaries
-    messages_for_replicate = [msg.dict() for msg in request.messages]
-
-    payload = {
-        "messages": messages_for_replicate
-    }
+    payload = {}
+
+    # *** THIS IS THE CRITICAL FIX ***
+    # Claude models on Replicate require a single 'prompt' string.
+    # We must convert the 'messages' array into a formatted string.
+    if "claude" in request.model:
+        prompt_parts = []
+        system_prompt = None
+        image_url = None
+
+        for msg in request.messages:
+            if msg.role == "system":
+                system_prompt = str(msg.content)
+            elif msg.role == "user":
+                if isinstance(msg.content, list):  # Vision case
+                    for item in msg.content:
+                        if item.get("type") == "text":
+                            prompt_parts.append(f"User: {item.get('text', '')}")
+                        elif item.get("type") == "image_url":
+                            image_url = item.get("image_url", {}).get("url")
+                else:  # Text-only case
+                    prompt_parts.append(f"User: {msg.content}")
+            elif msg.role == "assistant":
+                prompt_parts.append(f"Assistant: {msg.content}")
+
+        payload["prompt"] = "\n".join(prompt_parts)
+        if system_prompt:
+            payload["system_prompt"] = system_prompt
+        if image_url:
+            payload["image"] = image_url
+
+    # Other models like Llama-3 accept the 'messages' array directly.
+    else:
+        payload["messages"] = [msg.dict() for msg in request.messages]
 
-    # Add other compatible parameters
+    # Add common parameters
     if request.max_tokens is not None:
         payload["max_new_tokens"] = request.max_tokens
    if request.temperature is not None:
         payload["temperature"] = request.temperature
     if request.top_p is not None:
         payload["top_p"] = request.top_p
-
-    # Vision support: Find image URL in the last user message if present
-    last_user_message = next((m for m in reversed(request.messages) if m.role == 'user'), None)
-    if last_user_message and isinstance(last_user_message.content, list):
-        for item in last_user_message.content:
-            if item.get("type") == "image_url":
-                payload["image"] = item.get("image_url", {}).get("url")
-                # Reformat messages to be a simple prompt string for vision models if needed,
-                # as some might not support the `messages` format with images.
-                # For Claude Haiku, a prompt string is more reliable with images.
-                if "claude" in request.model:
-                    text_prompts = [item.get('text', '') for item in last_user_message.content if item.get('type') == 'text']
-                    payload["prompt"] = " ".join(text_prompts)
-                    del payload["messages"]
-                break
 
     return payload
 
 async def stream_replicate_native_sse(model_id: str, payload: dict):
-    """
-    Connects to Replicate's native SSE stream for true token-by-token streaming.
-    """
+    """Connects to Replicate's native SSE stream for token-by-token streaming."""
     url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
     async with httpx.AsyncClient(timeout=300) as client:
-        # 1. Create the prediction to get the stream URL
         try:
-            # Add stream=True to the outer payload for Replicate
             response = await client.post(url, headers=headers, json={"input": payload, "stream": True})
             response.raise_for_status()
             prediction = response.json()
@@ -126,10 +110,13 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
             yield json.dumps({"error": {"message": error_detail}})
             return
         except httpx.HTTPStatusError as e:
-            yield json.dumps({"error": {"message": e.response.text}})
+            try:
+                error_body = e.response.json()
+                yield json.dumps({"error": {"message": json.dumps(error_body)}})
+            except json.JSONDecodeError:
+                yield json.dumps({"error": {"message": e.response.text}})
             return
 
-        # 2. Connect to the SSE stream and yield OpenAI-compatible chunks
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}) as sse:
                 sse.raise_for_status()
@@ -146,11 +133,10 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
                         }
                         yield json.dumps(chunk)
                     elif current_event == "done":
-                        break # Exit loop when done event is received
+                        break
         except Exception as e:
             yield json.dumps({"error": {"message": f"Streaming error: {str(e)}"}})
 
-        # 3. Send the final DONE chunk
         done_chunk = {
             "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
             "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
@@ -158,9 +144,7 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
         yield json.dumps(done_chunk)
         yield "[DONE]"
 
-
 # --- API Endpoints ---
-
 @app.get("/v1/models", response_model=ModelList)
 async def list_models():
     return ModelList(data=[ModelCard(id=model_name) for model_name in SUPPORTED_MODELS.keys()])
@@ -186,14 +170,11 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
             response = await client.post(url, headers=headers, json={"input": replicate_input})
             response.raise_for_status()
             prediction = response.json()
-
             output = "".join(prediction.get("output", []))
-
             return JSONResponse(content={
                 "id": prediction["id"], "object": "chat.completion", "created": int(time.time()), "model": model_key,
                 "choices": [{"index": 0, "message": {"role": "assistant", "content": output}, "finish_reason": "stop"}],
                 "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
             })
-
         except httpx.HTTPStatusError as e:
             raise HTTPException(status_code=e.response.status_code, detail=e.response.text)
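
For reference, the conversion introduced in prepare_replicate_input above flattens an OpenAI-style messages array into the single prompt string that Claude models on Replicate expect. A minimal sketch of the text-only branch follows; the sample conversation is hypothetical, not taken from the commit.

# Illustration only: mirrors the text-only branch of the new
# prepare_replicate_input; the sample conversation below is hypothetical.
messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "And of Italy?"},
]

prompt_parts, system_prompt = [], None
for msg in messages:
    if msg["role"] == "system":
        system_prompt = str(msg["content"])
    elif msg["role"] == "user":
        prompt_parts.append(f"User: {msg['content']}")
    elif msg["role"] == "assistant":
        prompt_parts.append(f"Assistant: {msg['content']}")

payload = {"prompt": "\n".join(prompt_parts)}
if system_prompt:
    payload["system_prompt"] = system_prompt

# Resulting Replicate input:
# {
#   "prompt": "User: What is the capital of France?\nAssistant: Paris.\nUser: And of Italy?",
#   "system_prompt": "You are a terse assistant."
# }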
 
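
Once the updated app is running, a quick smoke test could look like the sketch below. The base URL, port, and the /v1/chat/completions route are assumptions about the local deployment (the route decorator is outside the hunks shown); the model key comes from SUPPORTED_MODELS.

# Hypothetical smoke test; assumes the app is served at http://localhost:8000
# and that the chat endpoint is mounted at /v1/chat/completions.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "claude-4.5-haiku",  # one of the keys in SUPPORTED_MODELS
        "messages": [
            {"role": "system", "content": "Answer briefly."},
            {"role": "user", "content": "Name one thing FastAPI is used for."},
        ],
        "max_tokens": 128,
        "stream": False,  # the non-streaming path returns a single JSON body
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])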