rkihacker committed
Commit f8a669a · verified · 1 Parent(s): 2fc646f

Update main.py

Files changed (1):
  1. main.py +10 -11
main.py CHANGED
@@ -20,7 +20,7 @@ if not REPLICATE_API_TOKEN:
 # --- FastAPI App Initialization ---
 app = FastAPI(
     title="Replicate to OpenAI Compatibility Layer",
-    version="2.1.0 (Model Input Fixed)",
+    version="2.2.0 (Stable Streaming)",
 )

 # --- Pydantic Models ---
@@ -45,14 +45,9 @@ SUPPORTED_MODELS = {
 # --- Helper Functions ---

 def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
-    """
-    Prepares the input payload for Replicate, handling model-specific formats.
-    """
+    """Prepares the input payload for Replicate, handling model-specific formats."""
     payload = {}

-    # *** THIS IS THE CRITICAL FIX ***
-    # Claude models on Replicate require a single 'prompt' string.
-    # We must convert the 'messages' array into a formatted string.
     if "claude" in request.model:
         prompt_parts = []
         system_prompt = None
@@ -72,6 +67,9 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
                 prompt_parts.append(f"User: {msg.content}")
             elif msg.role == "assistant":
                 prompt_parts.append(f"Assistant: {msg.content}")
+
+        # Add final turn for the assistant to respond
+        prompt_parts.append("Assistant:")

         payload["prompt"] = "\n".join(prompt_parts)
         if system_prompt:
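
For context, the comments removed in the previous hunk explained why this branch exists: Claude models on Replicate take a single prompt string rather than an OpenAI-style messages array. Below is a minimal, self-contained sketch of the flattening this code performs, using plain dicts in place of the app's Pydantic message models; the sample conversation and the "system_prompt" payload key are illustrative assumptions, since the diff only shows that the system message is tracked separately.

    # Sketch of the Claude prompt flattening, with plain dicts instead of
    # the app's Pydantic models. Conversation and "system_prompt" key are
    # illustrative assumptions, not taken from this commit.
    from typing import Any, Dict, List

    def flatten_for_claude(messages: List[Dict[str, str]]) -> Dict[str, Any]:
        payload: Dict[str, Any] = {}
        prompt_parts = []
        for msg in messages:
            if msg["role"] == "system":
                payload["system_prompt"] = msg["content"]
            elif msg["role"] == "user":
                prompt_parts.append(f"User: {msg['content']}")
            elif msg["role"] == "assistant":
                prompt_parts.append(f"Assistant: {msg['content']}")
        # The line added in this commit: cue the model to produce the next turn.
        prompt_parts.append("Assistant:")
        payload["prompt"] = "\n".join(prompt_parts)
        return payload

    print(flatten_for_claude([
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "Hi!"},
    ]))
    # {'system_prompt': 'Be concise.', 'prompt': 'User: Hi!\nAssistant:'}

Without the trailing "Assistant:" turn, a completion-style model may continue the user's text instead of answering it, which is presumably what this addition fixes.
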
@@ -79,11 +77,9 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
         if image_url:
             payload["image"] = image_url

-    # Other models like Llama-3 accept the 'messages' array directly.
-    else:
+    else:  # Llama-3 and other standard chat models
         payload["messages"] = [msg.dict() for msg in request.messages]

-    # Add common parameters
     if request.max_tokens is not None:
         payload["max_new_tokens"] = request.max_tokens
     if request.temperature is not None:
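
The else branch passes the messages array through to Replicate unchanged, while the common-parameter block renames OpenAI fields to their Replicate equivalents: the hunk shows max_tokens becoming max_new_tokens, and it is truncated right after the temperature check, so a direct copy there is an assumption. A sketch of the resulting payload for a hypothetical Llama-3 request:

    # Sketch of the payload built for a non-Claude model. The request values
    # and model id are illustrative; the temperature passthrough is assumed,
    # since the hunk cuts off after "if request.temperature is not None:".
    openai_request = {
        "model": "meta/meta-llama-3-8b-instruct",  # hypothetical model id
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 256,
        "temperature": 0.7,
    }

    payload = {
        "messages": openai_request["messages"],          # passed through as-is
        "max_new_tokens": openai_request["max_tokens"],  # OpenAI -> Replicate name
        "temperature": openai_request["temperature"],    # assumed direct copy
    }
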
@@ -126,7 +122,10 @@ async def stream_replicate_native_sse(model_id: str, payload: dict):
             current_event = line[len("event:"):].strip()
         elif line.startswith("data:"):
             data = line[len("data:"):].strip()
-            if current_event == "output":
+
+            # *** THIS IS THE CRITICAL FIX ***
+            # Only process non-empty data for 'output' events
+            if data and current_event == "output":
                 chunk = {
                     "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
                     "choices": [{"index": 0, "delta": {"content": json.loads(data)}, "finish_reason": None}]
 