rkihacker committed
Commit 54de3fd · verified · 1 Parent(s): a135be4

Update main.py

Files changed (1)
  1. main.py +33 -31
main.py CHANGED
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="7.0.0 (Unified Prompt Fix)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="7.1.0 (Streaming Space Fix)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -39,51 +39,41 @@ SUPPORTED_MODELS = {
 # --- Core Logic ---
 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     """
-    Formats the input for the Replicate API. This function now uses a unified approach
-    for all models, flattening the message history into a single 'prompt' string
-    and handling images separately, as required by Replicate's API.
+    Formats the input for Replicate's API, flattening the message history into a
+    single 'prompt' string and handling images separately. This is the required
+    format for all their current chat/vision models.
     """
     payload = {}
-
     prompt_parts = []
     system_prompt = None
     image_input = None
 
     for msg in request.messages:
         if msg.role == "system":
-            # Extract system prompt; it will be a separate parameter.
             system_prompt = str(msg.content)
         elif msg.role == "assistant":
             prompt_parts.append(f"Assistant: {msg.content}")
         elif msg.role == "user":
             user_text_content = ""
             if isinstance(msg.content, list):
-                # Handle multimodal (vision) input from OpenAI format
                 for item in msg.content:
                     if item.get("type") == "text":
                         user_text_content += item.get("text", "")
                     elif item.get("type") == "image_url":
                         image_url_data = item.get("image_url", {})
-                        # The 'image' parameter is used by Claude, Llava, etc., on Replicate
                        image_input = image_url_data.get("url")
             else:
                 user_text_content = str(msg.content)
-
             prompt_parts.append(f"User: {user_text_content}")
 
-    # The final "Assistant:" turn prompts the model for a response.
     prompt_parts.append("Assistant:")
-
-    # All models on Replicate's API expect a single 'prompt' string.
     payload["prompt"] = "\n\n".join(prompt_parts)
 
     if system_prompt:
         payload["system_prompt"] = system_prompt
-
     if image_input:
         payload["image"] = image_input
 
-    # Map common OpenAI parameters to Replicate equivalents
     if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
     if request.temperature: payload["temperature"] = request.temperature
     if request.top_p: payload["top_p"] = request.top_p
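
For illustration, here is a standalone sketch of the flattening behavior the hunk above converges on. It mirrors the diff's logic on plain dicts; the helper name flatten_messages and the sample messages are assumptions for the demo, and the image/vision branch is omitted:

# Sketch only: mirrors prepare_replicate_input without the Pydantic request
# model or the multimodal branch, so it runs on its own.
from typing import Any, Dict, List

def flatten_messages(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    prompt_parts: List[str] = []
    system_prompt = None
    for msg in messages:
        if msg["role"] == "system":
            system_prompt = str(msg["content"])
        elif msg["role"] == "assistant":
            prompt_parts.append(f"Assistant: {msg['content']}")
        elif msg["role"] == "user":
            prompt_parts.append(f"User: {msg['content']}")
    prompt_parts.append("Assistant:")  # final turn cues the model to answer
    payload: Dict[str, Any] = {"prompt": "\n\n".join(prompt_parts)}
    if system_prompt:
        payload["system_prompt"] = system_prompt
    return payload

print(flatten_messages([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Name a color."},
]))
# {'prompt': 'User: Name a color.\n\nAssistant:', 'system_prompt': 'You are terse.'}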
@@ -91,7 +81,7 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, A
     return payload
 
 async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
-    """Handles the full streaming lifecycle using standard Replicate endpoints."""
+    """Handles the full streaming lifecycle with robust token parsing."""
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
@@ -102,20 +92,17 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
             prediction = response.json()
             stream_url = prediction.get("urls", {}).get("stream")
             prediction_id = prediction.get("id", "stream-unknown")
-
             if not stream_url:
                 yield json.dumps({"error": {"message": "Model did not return a stream URL."}})
                 return
-
         except httpx.HTTPStatusError as e:
             error_details = e.response.text
             try:
                 error_json = e.response.json()
                 error_details = error_json.get("detail", error_details)
-            except json.JSONDecodeError:
-                pass
+            except json.JSONDecodeError: pass
             yield json.dumps({"error": {"message": f"Upstream Error: {error_details}", "type": "replicate_error"}})
             return
 
         try:
             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
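
For context: the fragments in this hunk sit inside a create-prediction call that the diff does not show. Below is a minimal sketch of what that call presumably looks like, assuming the Replicate pattern of posting {"input": ..., "stream": true} and reading "id" and "urls" -> "stream" from the response; the helper name create_prediction is illustrative:

import os

import httpx

async def create_prediction(model_id: str, input_payload: dict) -> dict:
    # Sketch: create a prediction and request an SSE stream URL for it.
    url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
    headers = {
        "Authorization": f"Bearer {os.environ['REPLICATE_API_TOKEN']}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(
            url, headers=headers, json={"input": input_payload, "stream": True}
        )
        response.raise_for_status()  # 4xx/5xx surfaces as httpx.HTTPStatusError
        return response.json()       # expected to carry "id" and "urls"["stream"]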
@@ -126,12 +113,29 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                     elif line.startswith("data:"):
                         data = line[len("data:"):].strip()
                         if current_event == "output":
-                            if data:
+                            # --- START OF STREAMING FIX ---
+                            # Replicate streams tokens that can be plain text or JSON-encoded strings.
+                            # We need to robustly parse them to preserve spaces correctly.
+                            content_token = ""
+                            try:
+                                # Attempt to parse data as JSON. This handles tokens like "\" Hello\""
+                                decoded_data = json.loads(data)
+                                if isinstance(decoded_data, str):
+                                    content_token = decoded_data
+                                else:
+                                    # It's some other JSON type, convert to string
+                                    content_token = str(decoded_data)
+                            except json.JSONDecodeError:
+                                # It's not valid JSON, so it's a plain text token.
+                                content_token = data
+
+                            if content_token:
                                 chunk = {
                                     "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
-                                    "choices": [{"index": 0, "delta": {"content": data}, "finish_reason": None}]
+                                    "choices": [{"index": 0, "delta": {"content": content_token}, "finish_reason": None}]
                                 }
                                 yield json.dumps(chunk)
+                            # --- END OF STREAMING FIX ---
                         elif current_event == "done":
                             break
         except httpx.ReadTimeout:
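
Why the JSON-first parse above matters: the .strip() on the "data:" line eats the leading space of a bare token, while a JSON-encoded token such as " world" keeps its space once json.loads decodes it. A runnable illustration of the same parse-or-passthrough rule (parse_token is an illustrative name, not part of main.py):

import json

def parse_token(data: str) -> str:
    # Mirror the fix: try JSON first, fall back to the raw text.
    try:
        decoded = json.loads(data)
        return decoded if isinstance(decoded, str) else str(decoded)
    except json.JSONDecodeError:
        return data

print(repr(parse_token('" world"')))  # ' world' (leading space preserved)
print(repr(parse_token('world')))     # 'world'  (plain text passes through)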
@@ -148,12 +152,10 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
 # --- Endpoints ---
 @app.get("/v1/models")
 async def list_models():
-    """Lists the currently supported models."""
     return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: OpenAIChatCompletionRequest):
-    """Handles chat completion requests, streaming or non-streaming."""
     if request.model not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")
 
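
Finally, an end-to-end usage sketch of the compatibility layer after this commit. Assumptions: the app is served locally on port 8000, the endpoint frames each yielded chunk as a standard SSE "data:" line (that framing code is outside this diff), and "example/model" stands in for a real key in SUPPORTED_MODELS:

import json

import httpx

with httpx.stream(
    "POST",
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "example/model",  # hypothetical; must match SUPPORTED_MODELS
        "messages": [{"role": "user", "content": "Say hi"}],
        "stream": True,
    },
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if line.startswith("data:"):
            payload = line[len("data:"):].strip()
            if payload and payload != "[DONE]":  # guard for an OpenAI-style terminator
                chunk = json.loads(payload)
                print(chunk["choices"][0]["delta"].get("content", ""), end="")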
 
 