rkihacker committed
Commit de4d166 · verified · 1 parent: f8a669a

Update main.py

Files changed (1): main.py (+27 −35)
main.py CHANGED
@@ -20,7 +20,7 @@ if not REPLICATE_API_TOKEN:
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="2.2.0 (Stable Streaming)",
+    version="2.3.0 (Definitive Streaming Fix)",
 )
 
 # --- Pydantic Models ---
@@ -52,41 +52,30 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
         prompt_parts = []
         system_prompt = None
         image_url = None
-
         for msg in request.messages:
             if msg.role == "system":
                 system_prompt = str(msg.content)
             elif msg.role == "user":
-                if isinstance(msg.content, list): # Vision case
+                if isinstance(msg.content, list):
                     for item in msg.content:
                         if item.get("type") == "text":
                             prompt_parts.append(f"User: {item.get('text', '')}")
                         elif item.get("type") == "image_url":
                             image_url = item.get("image_url", {}).get("url")
-                else: # Text-only case
+                else:
                     prompt_parts.append(f"User: {msg.content}")
             elif msg.role == "assistant":
                 prompt_parts.append(f"Assistant: {msg.content}")
-
-        # Add final turn for the assistant to respond
         prompt_parts.append("Assistant:")
-
         payload["prompt"] = "\n".join(prompt_parts)
-        if system_prompt:
-            payload["system_prompt"] = system_prompt
-        if image_url:
-            payload["image"] = image_url
-
-    else: # Llama-3 and other standard chat models
+        if system_prompt: payload["system_prompt"] = system_prompt
+        if image_url: payload["image"] = image_url
+    else:
         payload["messages"] = [msg.dict() for msg in request.messages]
 
-    if request.max_tokens is not None:
-        payload["max_new_tokens"] = request.max_tokens
-    if request.temperature is not None:
-        payload["temperature"] = request.temperature
-    if request.top_p is not None:
-        payload["top_p"] = request.top_p
-
+    if request.max_tokens is not None: payload["max_new_tokens"] = request.max_tokens
+    if request.temperature is not None: payload["temperature"] = request.temperature
+    if request.top_p is not None: payload["top_p"] = request.top_p
     return payload
 
 async def stream_replicate_native_sse(model_id: str, payload: dict):
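
The prompt branch above flattens an OpenAI-style message list into a single User:/Assistant: transcript, with a trailing "Assistant:" turn so the model answers the last user message. A minimal standalone sketch of that flattening, using plain dicts instead of the app's Pydantic models (names here are illustrative, not from main.py):

    def flatten_messages(messages: list) -> str:
        # Collapse chat turns into one prompt string; the trailing
        # "Assistant:" cues the model to answer the last user turn.
        parts = []
        for m in messages:
            if m["role"] != "system":
                parts.append(f"{m['role'].capitalize()}: {m['content']}")
        parts.append("Assistant:")
        return "\n".join(parts)

    # flatten_messages([{"role": "user", "content": "Hi"}])
    # -> "User: Hi\nAssistant:"
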
@@ -95,6 +84,7 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
     async with httpx.AsyncClient(timeout=300) as client:
+        prediction = None
         try:
             response = await client.post(url, headers=headers, json={"input": payload, "stream": True})
             response.raise_for_status()
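
The new `prediction = None` initialisation matters later: the final `done_chunk` reads `prediction["id"]`, and if this POST fails the variable would otherwise be unbound. For context, a hedged sketch of the request this block issues, per Replicate's public HTTP API (the SSE endpoint is expected under `urls.stream` in the returned prediction; not guaranteed by this repo):

    import os
    import httpx

    async def create_streaming_prediction(model: str, payload: dict) -> dict:
        # POST the input with "stream": True; the response is a prediction
        # object whose urls.stream field points at an SSE endpoint.
        headers = {"Authorization": f"Bearer {os.environ['REPLICATE_API_TOKEN']}"}
        async with httpx.AsyncClient(timeout=300) as client:
            r = await client.post(
                f"https://api.replicate.com/v1/models/{model}/predictions",
                headers=headers,
                json={"input": payload, "stream": True},
            )
            r.raise_for_status()
            return r.json()
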
@@ -106,11 +96,8 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
             yield json.dumps({"error": {"message": error_detail}})
             return
         except httpx.HTTPStatusError as e:
-            try:
-                error_body = e.response.json()
-                yield json.dumps({"error": {"message": json.dumps(error_body)}})
-            except json.JSONDecodeError:
-                yield json.dumps({"error": {"message": e.response.text}})
+            try: yield json.dumps({"error": {"message": json.dumps(e.response.json())}})
+            except: yield json.dumps({"error": {"message": e.response.text}})
             return
 
         try:
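
The collapsed handler keeps the same fallback logic: prefer the structured JSON error body, fall back to the raw text. One caveat with the one-liner form is that the bare `except:` catches everything, including `KeyboardInterrupt`; a narrower equivalent of the pattern, catching only the decode failure as the old code did:

    import json
    import httpx

    def error_message(resp: httpx.Response) -> str:
        # Prefer the JSON error body; fall back to the raw text.
        try:
            return json.dumps(resp.json())
        except json.JSONDecodeError:
            return resp.text
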
@@ -123,21 +110,27 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
                 elif line.startswith("data:"):
                     data = line[len("data:"):].strip()
 
-                    # *** THIS IS THE CRITICAL FIX ***
-                    # Only process non-empty data for 'output' events
-                    if data and current_event == "output":
-                        chunk = {
-                            "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
-                            "choices": [{"index": 0, "delta": {"content": json.loads(data)}, "finish_reason": None}]
-                        }
-                        yield json.dumps(chunk)
+                    if current_event == "output":
+                        # *** THIS IS THE DEFINITIVE FIX ***
+                        # Wrap the JSON parsing in a try-except block to gracefully
+                        # handle empty or malformed data lines without crashing.
+                        try:
+                            content = json.loads(data)
+                            chunk = {
+                                "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
+                                "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}]
+                            }
+                            yield json.dumps(chunk)
+                        except json.JSONDecodeError:
+                            # This will silently ignore any non-JSON data, like empty strings.
+                            pass
                 elif current_event == "done":
                     break
         except Exception as e:
             yield json.dumps({"error": {"message": f"Streaming error: {str(e)}"}})
 
         done_chunk = {
-            "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
+            "id": prediction["id"] if prediction else "unknown", "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
             "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
         }
         yield json.dumps(done_chunk)
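
The new guard can be checked in isolation: output events carry JSON-encoded string fragments, while blank or non-JSON data lines now fall through to `pass` instead of raising mid-stream. A quick sketch of that behaviour (sample payloads are illustrative):

    import json

    for data in ['"Hello"', '" world"', '', ': ping']:
        try:
            print(repr(json.loads(data)))  # prints 'Hello' then ' world'
        except json.JSONDecodeError:
            pass  # '' and ': ping' are silently skipped
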
@@ -160,7 +153,6 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     if request.stream:
         return EventSourceResponse(stream_replicate_native_sse(replicate_model_id, replicate_input))
 
-    # Synchronous request
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
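
Only the comment was removed here; the synchronous path itself is unchanged. The `Prefer: wait=120` header asks Replicate to hold the connection open for up to 120 seconds and return the finished prediction in a single response. A hedged sketch of that call outside the app (function name and timeout are illustrative):

    import os
    import httpx

    def sync_predict(model: str, payload: dict) -> dict:
        # Blocking call: Replicate holds the request open (up to 120 s)
        # and returns the completed prediction in one response.
        r = httpx.post(
            f"https://api.replicate.com/v1/models/{model}/predictions",
            headers={
                "Authorization": f"Bearer {os.environ['REPLICATE_API_TOKEN']}",
                "Prefer": "wait=120",
            },
            json={"input": payload},
            timeout=130,
        )
        r.raise_for_status()
        return r.json()
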
 
 