Update main.py
Browse files
main.py
CHANGED
|
@@ -23,12 +23,11 @@ POLLING_INTERVAL_SECONDS = 1 # How often to poll for updates
|
|
| 23 |
# --- FastAPI App Initialization ---
|
| 24 |
app = FastAPI(
|
| 25 |
title="Replicate to OpenAI Compatibility Layer",
|
| 26 |
-
version="1.1.
|
| 27 |
)
|
| 28 |
|
| 29 |
-
# --- Pydantic Models for OpenAI Compatibility
|
| 30 |
|
| 31 |
-
# /v1/models endpoint
|
| 32 |
class ModelCard(BaseModel):
|
| 33 |
id: str
|
| 34 |
object: str = "model"
|
|
@@ -39,7 +38,6 @@ class ModelList(BaseModel):
|
|
| 39 |
object: str = "list"
|
| 40 |
data: List[ModelCard] = []
|
| 41 |
|
| 42 |
-
# /v1/chat/completions endpoint
|
| 43 |
class ChatMessage(BaseModel):
|
| 44 |
role: Literal["system", "user", "assistant", "tool"]
|
| 45 |
content: Union[str, List[Dict[str, Any]]]
|
|
@@ -76,8 +74,10 @@ def format_tools_for_prompt(tools: List[Tool]) -> str:
|
|
| 76 |
"""Converts OpenAI tools to a string for the system prompt."""
|
| 77 |
if not tools:
|
| 78 |
return ""
|
|
|
|
| 79 |
prompt = "You have access to the following tools. To use a tool, respond with a JSON object in the following format:\n"
|
| 80 |
-
|
|
|
|
| 81 |
prompt += "Available tools:\n"
|
| 82 |
for tool in tools:
|
| 83 |
prompt += json.dumps(tool.function.dict(), indent=2) + "\n"
|
|
@@ -128,7 +128,6 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, A
|
|
| 128 |
async def stream_replicate_with_polling(model_id: str, payload: dict):
|
| 129 |
"""
|
| 130 |
Creates a prediction and then polls the 'get' URL to stream back results.
|
| 131 |
-
This is a reliable alternative to Replicate's native SSE stream.
|
| 132 |
"""
|
| 133 |
url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
|
| 134 |
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
|
|
@@ -181,7 +180,6 @@ async def stream_replicate_with_polling(model_id: str, payload: dict):
|
|
| 181 |
previous_output = current_output
|
| 182 |
|
| 183 |
except httpx.HTTPStatusError as e:
|
| 184 |
-
# Don't stop polling on temporary network errors
|
| 185 |
print(f"Warning: Polling failed with status {e.response.status_code}, retrying...")
|
| 186 |
except Exception as e:
|
| 187 |
yield f"data: {json.dumps({'error': f'Polling error: {str(e)}'})}\n\n"
|
|
@@ -218,10 +216,9 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
|
|
| 218 |
replicate_input = prepare_replicate_input(request)
|
| 219 |
|
| 220 |
if request.stream:
|
| 221 |
-
# Use the new reliable polling-based streamer
|
| 222 |
return EventSourceResponse(stream_replicate_with_polling(replicate_model_id, replicate_input))
|
| 223 |
|
| 224 |
-
# Synchronous request
|
| 225 |
url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
|
| 226 |
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
|
| 227 |
|
|
|
|
| 23 |
# --- FastAPI App Initialization ---
|
| 24 |
app = FastAPI(
|
| 25 |
title="Replicate to OpenAI Compatibility Layer",
|
| 26 |
+
version="1.1.1 (SyntaxError Fixed)",
|
| 27 |
)
|
| 28 |
|
| 29 |
+
# --- Pydantic Models for OpenAI Compatibility ---
|
| 30 |
|
|
|
|
| 31 |
class ModelCard(BaseModel):
|
| 32 |
id: str
|
| 33 |
object: str = "model"
|
|
|
|
| 38 |
object: str = "list"
|
| 39 |
data: List[ModelCard] = []
|
| 40 |
|
|
|
|
| 41 |
class ChatMessage(BaseModel):
|
| 42 |
role: Literal["system", "user", "assistant", "tool"]
|
| 43 |
content: Union[str, List[Dict[str, Any]]]
|
|
|
|
| 74 |
"""Converts OpenAI tools to a string for the system prompt."""
|
| 75 |
if not tools:
|
| 76 |
return ""
|
| 77 |
+
|
| 78 |
prompt = "You have access to the following tools. To use a tool, respond with a JSON object in the following format:\n"
|
| 79 |
+
# *** THIS IS THE CORRECTED LINE ***
|
| 80 |
+
prompt += '{"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}\n\n'
|
| 81 |
prompt += "Available tools:\n"
|
| 82 |
for tool in tools:
|
| 83 |
prompt += json.dumps(tool.function.dict(), indent=2) + "\n"
|
|
|
|
| 128 |
async def stream_replicate_with_polling(model_id: str, payload: dict):
|
| 129 |
"""
|
| 130 |
Creates a prediction and then polls the 'get' URL to stream back results.
|
|
|
|
| 131 |
"""
|
| 132 |
url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
|
| 133 |
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
|
|
|
|
| 180 |
previous_output = current_output
|
| 181 |
|
| 182 |
except httpx.HTTPStatusError as e:
|
|
|
|
| 183 |
print(f"Warning: Polling failed with status {e.response.status_code}, retrying...")
|
| 184 |
except Exception as e:
|
| 185 |
yield f"data: {json.dumps({'error': f'Polling error: {str(e)}'})}\n\n"
|
|
|
|
| 216 |
replicate_input = prepare_replicate_input(request)
|
| 217 |
|
| 218 |
if request.stream:
|
|
|
|
| 219 |
return EventSourceResponse(stream_replicate_with_polling(replicate_model_id, replicate_input))
|
| 220 |
|
| 221 |
+
# Synchronous request
|
| 222 |
url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
|
| 223 |
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
|
| 224 |
|