rkihacker committed
Commit dafbe9c · verified · 1 Parent(s): 7b0c05f

Update main.py

Files changed (1):
  1. main.py +98 -113
main.py CHANGED
@@ -1,13 +1,14 @@
+
 import os
 import httpx
 import json
 import time
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import Response
+from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
 from typing import List, Dict, Any, Optional, Union, Literal
 from dotenv import load_dotenv
-import asyncio
+from sse_starlette.sse import EventSourceResponse
 
 # Load environment variables
 load_dotenv()
@@ -16,36 +17,23 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="10.0.0 (Enhanced Chunk Formatting)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.0.0 (Definitive Streaming Fix)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
-    id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "replicate"
-
+    id: str; object: str = "model"; created: int = Field(default_factory=lambda: int(time.time())); owned_by: str = "replicate"
 class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = []
-
+    object: str = "list"; data: List[ModelCard] = []
 class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant", "tool"]
-    content: Union[str, List[Dict[str, Any]]]
-
+    role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]
 class OpenAIChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[ChatMessage]
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 1.0
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
+    model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
 
 # --- Supported Models ---
 SUPPORTED_MODELS = {
     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-    "claude-4.5-haiku": "anthropic/claude-4.5-haiku", # Note: Name changed for clarity
-    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet", # Note: Name changed for clarity
+    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
+    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
     "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
 }
 
@@ -92,136 +80,136 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
 
     return payload
 
-def get_provider(replicate_model_id: str) -> str:
-    """Infers the provider from the Replicate model ID."""
-    if replicate_model_id.startswith("meta/"):
-        return "Meta"
-    if replicate_model_id.startswith("anthropic/"):
-        return "Anthropic"
-    if "llava" in replicate_model_id:
-        return "Llava"
-    return "Replicate"
-
-async def stream_replicate_sse(replicate_model_id: str, requested_model_name: str, input_payload: dict):
-    """
-    Handles the full streaming lifecycle with corrected whitespace preservation
-    and the new, detailed chunk format.
-    """
+async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
+    """Handles the full streaming lifecycle with correct whitespace preservation."""
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
-    # Identify provider for the response chunks
-    provider = get_provider(replicate_model_id)
-
     async with httpx.AsyncClient(timeout=60.0) as client:
-        # 1. Create the prediction and get the stream URL
         try:
             response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
             response.raise_for_status()
             prediction = response.json()
             stream_url = prediction.get("urls", {}).get("stream")
-            prediction_id = prediction.get("id", f"stream-{int(time.time())}")
-
+            prediction_id = prediction.get("id", "stream-unknown")
             if not stream_url:
-                error_chunk = { "error": {"message": "Model did not return a stream URL."} }
-                yield f"data: {json.dumps(error_chunk)}\n\n"
+                yield json.dumps({'error': {'message': 'Model did not return a stream URL.'}})
                 return
-
         except httpx.HTTPStatusError as e:
             error_details = e.response.text
             try:
                 error_json = e.response.json()
                 error_details = error_json.get("detail", error_details)
             except json.JSONDecodeError: pass
-            error_chunk = {"error": {"message": f"Upstream Error: {error_details}", "type": "replicate_error"}}
-            yield f"data: {json.dumps(error_chunk)}\n\n"
+            yield json.dumps({'error': {'message': f'Upstream Error: {error_details}', 'type': 'replicate_error'}})
             return
-
-        # 2. Connect to the SSE stream and yield formatted chunks
+
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                 current_event = None
                 async for line in sse.aiter_lines():
-                    if not line:
+                    if not line: # Skip empty lines
                         continue
                     if line.startswith("event:"):
                         current_event = line[len("event:"):].strip()
                     elif line.startswith("data:"):
-                        # Get the raw payload after "data:"
-                        raw_payload = line[len("data:"):]
+                        # FIXED: Preserve all whitespace including leading/trailing spaces
+                        raw_data = line[5:] # Remove "data:" prefix
+
+                        # Remove only the optional single space after data: if present
+                        # This is per SSE spec and preserves actual content spaces
+                        if raw_data.startswith(" "):
+                            data_content = raw_data[1:] # Remove the first space only
+                        else:
+                            data_content = raw_data
 
-                        # The SSE spec allows an optional leading space. Remove it.
-                        # This robustly prevents parsing errors without destroying content.
-                        payload = raw_payload.lstrip(" ")
-
                         if current_event == "output":
-                            if not payload:
+                            if not data_content:
                                 continue
-
+
                             content_token = ""
                             try:
-                                # This handles JSON-encoded strings like "\" Hello\"" and correctly
-                                # preserves all whitespace, including single spaces. This is the fix.
-                                content_token = json.loads(payload)
+                                # Handle JSON-encoded strings properly (including spaces)
+                                content_token = json.loads(data_content)
                             except (json.JSONDecodeError, TypeError):
-                                # Fallback for plain text tokens if Replicate changes format
-                                content_token = payload
+                                # Handle plain text tokens (preserve as-is)
+                                content_token = data_content
 
-                            # Build the new, detailed chunk structure
+                            # Create chunk with exact format you specified
                             chunk = {
-                                "id": prediction_id,
-                                "object": "chat.completion.chunk",
-                                "created": int(time.time()),
-                                "model": requested_model_name,
-                                "provider": provider,
                                 "choices": [{
-                                    "index": 0,
                                     "delta": {"content": content_token},
                                     "finish_reason": None,
+                                    "index": 0,
                                     "logprobs": None,
                                     "native_finish_reason": None
-                                }]
+                                }],
+                                "created": int(time.time()),
+                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}", # Format like your example
+                                "model": replicate_model_id,
+                                "object": "chat.completion.chunk",
+                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
                             }
-                            yield f"data: {json.dumps(chunk)}\n\n"
-
+                            # FIXED: Yield only the JSON data, let EventSourceResponse handle the SSE formatting
+                            yield json.dumps(chunk)
+
                         elif current_event == "done":
+                            # Send usage chunk before done
+                            usage_chunk = {
+                                "choices": [{
+                                    "delta": {},
+                                    "finish_reason": None,
+                                    "index": 0,
+                                    "logprobs": None,
+                                    "native_finish_reason": None
+                                }],
+                                "created": int(time.time()),
+                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
+                                "model": replicate_model_id,
+                                "object": "chat.completion.chunk",
+                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate",
+                                "usage": {
+                                    "cache_discount": 0,
+                                    "completion_tokens": 0,
+                                    "completion_tokens_details": {"image_tokens": 0, "reasoning_tokens": 0},
+                                    "cost": 0,
+                                    "cost_details": {
+                                        "upstream_inference_completions_cost": 0,
+                                        "upstream_inference_cost": None,
+                                        "upstream_inference_prompt_cost": 0
+                                    },
+                                    "input_tokens": 0,
+                                    "is_byok": False,
+                                    "prompt_tokens": 0,
+                                    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+                                    "total_tokens": 0
+                                }
+                            }
+                            yield json.dumps(usage_chunk)
+
+                            # Send final chunk with stop reason
+                            final_chunk = {
+                                "choices": [{
+                                    "delta": {},
+                                    "finish_reason": "stop",
+                                    "index": 0,
+                                    "logprobs": None,
+                                    "native_finish_reason": "end_turn"
+                                }],
+                                "created": int(time.time()),
+                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
+                                "model": replicate_model_id,
+                                "object": "chat.completion.chunk",
+                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
+                            }
+                            yield json.dumps(final_chunk)
                             break
         except httpx.ReadTimeout:
-            error_chunk = {"error": {"message": "Stream timed out.", "type": "timeout_error"}}
-            yield f"data: {json.dumps(error_chunk)}\n\n"
+            yield json.dumps({'error': {'message': 'Stream timed out.', 'type': 'timeout_error'}})
             return
 
-    # 3. Send the final chunk with finish_reason
-    final_chunk = {
-        "id": prediction_id,
-        "object": "chat.completion.chunk",
-        "created": int(time.time()),
-        "model": requested_model_name,
-        "provider": provider,
-        "choices": [{
-            "index": 0,
-            "delta": {},
-            "finish_reason": "stop",
-            "logprobs": None,
-            "native_finish_reason": "end_turn"
-        }]
-    }
-    yield f"data: {json.dumps(final_chunk)}\n\n"
-    yield "data: [DONE]\n\n"
-
-# A simple EventSourceResponse implementation if sse-starlette is not preferred
-async def create_sse_response(generator):
-    headers = {
-        'Content-Type': 'text/event-stream',
-        'Cache-Control': 'no-cache',
-        'Connection': 'keep-alive',
-    }
-    async def stream():
-        async for chunk in generator:
-            yield chunk
-            await asyncio.sleep(0) # Yield control to the event loop
-    return Response(stream(), headers=headers)
-
+    # Send [DONE] event
+    yield "[DONE]"
 
 # --- Endpoints ---
 @app.get("/v1/models")
@@ -233,16 +221,13 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
-    replicate_model_id = SUPPORTED_MODELS[request.model]
     replicate_input = prepare_replicate_input(request)
 
     if request.stream:
-        # Use the custom generator with the detailed chunk format
-        generator = stream_replicate_sse(replicate_model_id, request.model, replicate_input)
-        return await create_sse_response(generator)
+        return EventSourceResponse(stream_replicate_sse(SUPPORTED_MODELS[request.model], replicate_input), media_type="text/event-stream")
 
     # Non-streaming fallback
-    url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
+    url = f"https://api.replicate.com/v1/models/{SUPPORTED_MODELS[request.model]}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
     async with httpx.AsyncClient() as client:
        try:
@@ -256,4 +241,4 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
                 "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
             }
         except httpx.HTTPStatusError as e:
-            raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
+            raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
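For reference, a minimal client sketch for the streaming path this commit moves to sse-starlette: EventSourceResponse frames each string the generator yields as an SSE "data:" line, so a caller only needs to read "data:" lines, JSON-decode each chunk, and stop at "[DONE]". The base URL and the /v1/chat/completions route below are assumptions (the route decorator for create_chat_completion sits outside the hunks shown above); the model name comes from SUPPORTED_MODELS.

import json

import httpx

BASE_URL = "http://localhost:8000"  # assumed host/port, adjust to your deployment


def stream_chat(prompt: str, model: str = "claude-4.5-haiku") -> str:
    """Consume the proxy's SSE stream and return the concatenated completion text."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    text = []
    with httpx.Client(timeout=None) as client:
        # Assumed route: POST /v1/chat/completions on the app defined in main.py
        with client.stream("POST", f"{BASE_URL}/v1/chat/completions", json=payload) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line.startswith("data:"):
                    continue  # skip "event:" lines and keep-alive comments
                data = line[len("data:"):].lstrip(" ")
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                if "error" in chunk:
                    # Error chunks are yielded by stream_replicate_sse as {"error": {...}}
                    raise RuntimeError(chunk["error"]["message"])
                delta = chunk["choices"][0]["delta"].get("content", "")
                print(delta, end="", flush=True)
                text.append(delta)
    return "".join(text)


if __name__ == "__main__":
    stream_chat("Say hello in one short sentence.")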