rkihacker committed on
Commit 63b36c2 · verified · 1 Parent(s): de4d166

Update main.py

Files changed (1)
  1. main.py +49 -43
main.py CHANGED
@@ -20,7 +20,7 @@ if not REPLICATE_API_TOKEN:
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="2.3.0 (Definitive Streaming Fix)",
+    version="3.0.0 (Production Grade)",
 )
 
 # --- Pydantic Models ---
@@ -36,57 +36,66 @@ class ChatMessage(BaseModel):
 class OpenAIChatCompletionRequest(BaseModel):
     model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False
 
-# --- Model Mapping ---
+# --- Model Mapping with Explicit Version Hashes (Inspired by LiteLLM) ---
 SUPPORTED_MODELS = {
-    "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
-    "claude-4.5-haiku": "anthropic/claude-4.5-haiku"
+    "llama3-8b-instruct": {
+        "id": "meta/meta-llama-3-8b-instruct",
+        "version": "02741d1be9a932e6566058d4c92ab80332f143003b5a874f63c9b743e4f3583c",
+        "input_type": "messages"
+    },
+    "claude-4.5-haiku": {
+        "id": "anthropic/claude-4.5-haiku",
+        "version": "311c5ff9b9f71c9ebd401b34a41ce604a8b735def3a4aad56f671302b5c56784",
+        "input_type": "prompt"
+    }
 }
 
 # --- Helper Functions ---
 
-def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
-    """Prepares the input payload for Replicate, handling model-specific formats."""
-    payload = {}
+def build_replicate_request_body(request: OpenAIChatCompletionRequest, model_details: dict) -> dict:
+    """Builds the complete request body, including the crucial version hash."""
+    input_payload = {}
 
-    if "claude" in request.model:
+    # Handle model-specific input format (prompt vs messages)
+    if model_details["input_type"] == "prompt":
         prompt_parts = []
         system_prompt = None
-        image_url = None
         for msg in request.messages:
             if msg.role == "system":
                 system_prompt = str(msg.content)
             elif msg.role == "user":
-                if isinstance(msg.content, list):
-                    for item in msg.content:
-                        if item.get("type") == "text":
-                            prompt_parts.append(f"User: {item.get('text', '')}")
-                        elif item.get("type") == "image_url":
-                            image_url = item.get("image_url", {}).get("url")
-                else:
-                    prompt_parts.append(f"User: {msg.content}")
+                prompt_parts.append(f"User: {msg.content}")
             elif msg.role == "assistant":
                 prompt_parts.append(f"Assistant: {msg.content}")
-        prompt_parts.append("Assistant:")
-        payload["prompt"] = "\n".join(prompt_parts)
-        if system_prompt: payload["system_prompt"] = system_prompt
-        if image_url: payload["image"] = image_url
-    else:
-        payload["messages"] = [msg.dict() for msg in request.messages]
-
-    if request.max_tokens is not None: payload["max_new_tokens"] = request.max_tokens
-    if request.temperature is not None: payload["temperature"] = request.temperature
-    if request.top_p is not None: payload["top_p"] = request.top_p
-    return payload
-
-async def stream_replicate_native_sse(model_id: str, payload: dict):
-    """Connects to Replicate's native SSE stream for token-by-token streaming."""
-    url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
+        prompt_parts.append("Assistant:")  # Cue the model to respond
+        input_payload["prompt"] = "\n".join(prompt_parts)
+        if system_prompt: input_payload["system_prompt"] = system_prompt
+    else:  # "messages"
+        input_payload["messages"] = [msg.dict() for msg in request.messages]
+
+    # Add common parameters
+    if request.max_tokens is not None: input_payload["max_new_tokens"] = request.max_tokens
+    if request.temperature is not None: input_payload["temperature"] = request.temperature
+    if request.top_p is not None: input_payload["top_p"] = request.top_p
+
+    return {
+        "version": model_details["version"],
+        "input": input_payload
+    }
+
+async def stream_replicate_native_sse(model_id: str, request_body: dict):
+    """Connects to Replicate's native SSE stream for true token-by-token streaming."""
+    # Note: We call the generic predictions endpoint when providing a version hash.
+    url = "https://api.replicate.com/v1/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
+    # Add stream=True to the request body
+    request_body["stream"] = True
+
     async with httpx.AsyncClient(timeout=300) as client:
         prediction = None
         try:
-            response = await client.post(url, headers=headers, json={"input": payload, "stream": True})
+            response = await client.post(url, headers=headers, json=request_body)
             response.raise_for_status()
             prediction = response.json()
             stream_url = prediction.get("urls", {}).get("stream")
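
For reference, a minimal sketch of the body build_replicate_request_body would produce for the prompt-style model, assuming a hypothetical chat with one system and one user message (the version hash is the one pinned in SUPPORTED_MODELS; max_new_tokens is omitted because max_tokens defaults to None):

# Illustrative only — not captured output of the app.
body = {
    "version": "311c5ff9b9f71c9ebd401b34a41ce604a8b735def3a4aad56f671302b5c56784",
    "input": {
        "prompt": "User: What is the capital of France?\nAssistant:",
        "system_prompt": "You are concise.",
        "temperature": 0.7,
        "top_p": 1.0,
    },
}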
@@ -109,11 +118,7 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
                     current_event = line[len("event:"):].strip()
                 elif line.startswith("data:"):
                     data = line[len("data:"):].strip()
-
                     if current_event == "output":
-                        # *** THIS IS THE DEFINITIVE FIX ***
-                        # Wrap the JSON parsing in a try-except block to gracefully
-                        # handle empty or malformed data lines without crashing.
                         try:
                             content = json.loads(data)
                             chunk = {
@@ -122,7 +127,7 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
                             }
                             yield json.dumps(chunk)
                         except json.JSONDecodeError:
-                            # This will silently ignore any non-JSON data, like empty strings.
+                            # Silently ignore malformed or empty data lines
                             pass
                     elif current_event == "done":
                         break
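
The try/except above is load-bearing: the stream can carry empty or otherwise non-JSON data lines, and json.loads raises on them. A self-contained check with hypothetical payloads:

import json

# Typical "data:" payloads decode to strings; empty lines do not decode.
for data in ['"Hel"', '"lo"', '', '   ']:
    try:
        print(json.loads(data), end="")  # prints: Hello
    except json.JSONDecodeError:
        pass  # skipped silently, as in the handler above
print()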
@@ -147,18 +152,19 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     if model_key not in SUPPORTED_MODELS:
         raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")
 
-    replicate_model_id = SUPPORTED_MODELS[model_key]
-    replicate_input = prepare_replicate_input(request)
+    model_details = SUPPORTED_MODELS[model_key]
+    replicate_request_body = build_replicate_request_body(request, model_details)
 
     if request.stream:
-        return EventSourceResponse(stream_replicate_native_sse(replicate_model_id, replicate_input))
+        return EventSourceResponse(stream_replicate_native_sse(model_details["id"], replicate_request_body))
 
-    url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
+    # Synchronous request
+    url = "https://api.replicate.com/v1/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
 
     async with httpx.AsyncClient(timeout=150) as client:
         try:
-            response = await client.post(url, headers=headers, json={"input": replicate_input})
+            response = await client.post(url, headers=headers, json=replicate_request_body)
             response.raise_for_status()
             prediction = response.json()
             output = "".join(prediction.get("output", []))
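
On the non-streaming path, the handler relies only on the prediction's output field, which arrives as a list of string chunks (hence the join). A minimal sketch, with illustrative values:

# Only the field this handler reads is shown; values are made up.
prediction = {"output": ["The capital of ", "France is Paris."]}
print("".join(prediction.get("output", [])))  # The capital of France is Paris.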
 
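
Assuming the app is running locally on port 8000 and the route is mounted at the usual OpenAI path (/v1/chat/completions — the decorator sits outside this diff), a quick smoke test could look like:

# Hypothetical smoke test; host, port, and route are assumptions.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "llama3-8b-instruct",
        "messages": [{"role": "user", "content": "Say hello."}],
        "stream": False,
    },
    timeout=150,
)
resp.raise_for_status()
print(resp.json())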