import asyncio
import json
import os
import time
from typing import Any, Dict, List, Literal, Optional, Union

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

load_dotenv()

REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    raise ValueError("REPLICATE_API_TOKEN environment variable not set.")

POLLING_INTERVAL_SECONDS = 1

app = FastAPI(
    title="Replicate to OpenAI Compatibility Layer",
    version="1.1.0 (Polling Strategy)",
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "replicate"


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: Union[str, List[Dict[str, Any]]]


class ToolFunction(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]


class Tool(BaseModel):
    type: Literal["function"]
    function: ToolFunction


class OpenAIChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict]] = None
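
# For reference, a request body this schema accepts (hypothetical values):
#   {
#     "model": "llama3-8b-instruct",
#     "messages": [
#       {"role": "system", "content": "Be concise."},
#       {"role": "user", "content": "What is the capital of France?"}
#     ],
#     "max_tokens": 256,
#     "stream": true
#   }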

SUPPORTED_MODELS = {
    "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
}


def format_tools_for_prompt(tools: List[Tool]) -> str:
    """Converts OpenAI tools into a plain-text description for the system prompt."""
    if not tools:
        return ""
    prompt = (
        "You have access to the following tools. To use a tool, respond with a "
        "JSON object in the following format:\n"
    )
    prompt += '{"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}\n\n'
    prompt += "Available tools:\n"
    for tool in tools:
        prompt += json.dumps(tool.function.dict(), indent=2) + "\n"
    return prompt
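
# For illustration, a hypothetical tool definition such as
#   Tool(type="function", function=ToolFunction(
#       name="get_weather",
#       description="Look up the current weather for a city",
#       parameters={"type": "object", "properties": {"city": {"type": "string"}}}))
# renders into the system prompt roughly as:
#   You have access to the following tools. ...
#   {"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}
#   Available tools:
#   {"name": "get_weather", "description": "...", "parameters": {...}}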


def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
    """Prepares the input payload for the Replicate API."""
    input_data = {}
    prompt_parts = []
    system_prompt = ""
    image_url = None

    for message in request.messages:
        if message.role == "system":
            system_prompt += str(message.content) + "\n"
        elif message.role == "user":
            content = message.content
            if isinstance(content, list):
                for item in content:
                    if item.get("type") == "text":
                        prompt_parts.append(f"User: {item.get('text', '')}")
                    elif item.get("type") == "image_url":
                        image_url = item.get("image_url", {}).get("url")
            else:
                prompt_parts.append(f"User: {str(content)}")
        elif message.role == "assistant":
            prompt_parts.append(f"Assistant: {str(message.content)}")

    if request.tools:
        tool_prompt = format_tools_for_prompt(request.tools)
        system_prompt += "\n" + tool_prompt

    input_data["prompt"] = "\n".join(prompt_parts)
    if system_prompt:
        input_data["system_prompt"] = system_prompt
    if image_url:
        input_data["image"] = image_url

    if request.temperature is not None:
        input_data["temperature"] = request.temperature
    if request.top_p is not None:
        input_data["top_p"] = request.top_p
    if request.max_tokens is not None:
        input_data["max_new_tokens"] = request.max_tokens

    return input_data
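
# For example (hypothetical values), a request containing one system message
# ("Be brief.") and one user turn ("Hi"), with max_tokens=256, maps to:
#   {"prompt": "User: Hi", "system_prompt": "Be brief.\n",
#    "temperature": 0.7, "top_p": 1.0, "max_new_tokens": 256}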


async def stream_replicate_with_polling(model_id: str, payload: dict):
    """
    Creates a prediction and then polls the 'get' URL to stream back results.
    This is a reliable alternative to Replicate's native SSE stream.

    Each yielded string is the payload of one SSE event. EventSourceResponse
    adds the `data: ` framing and trailing newlines itself, so the yielded
    strings must not include them.
    """
    url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}

    async with httpx.AsyncClient(timeout=300) as client:
        # 1. Start the prediction
        try:
            response = await client.post(url, headers=headers, json={"input": payload})
            response.raise_for_status()
            prediction = response.json()
            get_url = prediction.get("urls", {}).get("get")

            if not get_url:
                error_detail = prediction.get("detail", "Failed to start prediction.")
                yield json.dumps({"error": error_detail})
                return
        except httpx.HTTPStatusError as e:
            yield json.dumps({"error": str(e.response.text)})
            return

        # 2. Poll the prediction 'get' URL for updates
        previous_output = ""
        status = ""
        while status not in ["succeeded", "failed", "canceled"]:
            await asyncio.sleep(POLLING_INTERVAL_SECONDS)
            try:
                poll_response = await client.get(get_url, headers=headers)
                poll_response.raise_for_status()
                prediction_update = poll_response.json()
                status = prediction_update["status"]

                if status == "failed":
                    error_detail = prediction_update.get("error", "Prediction failed.")
                    yield json.dumps({"error": error_detail})
                    break

                if "output" in prediction_update and prediction_update["output"] is not None:
                    current_output = "".join(prediction_update["output"])
                    new_chunk = current_output[len(previous_output):]

                    if new_chunk:
                        chunk = {
                            "id": prediction["id"],
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model_id,
                            "choices": [{"index": 0, "delta": {"content": new_chunk}, "finish_reason": None}],
                        }
                        yield json.dumps(chunk)
                        previous_output = current_output

            except httpx.HTTPStatusError as e:
                # Don't stop polling on temporary network errors
                print(f"Warning: Polling failed with status {e.response.status_code}, retrying...")
            except Exception as e:
                yield json.dumps({"error": f"Polling error: {str(e)}"})
                break

        # Send the final done signal
        done_chunk = {
            "id": prediction["id"],
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_id,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop" if status == "succeeded" else "error"}],
        }
        yield json.dumps(done_chunk)
        yield "[DONE]"


# --- API Endpoints ---

@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Lists the available models."""
    model_cards = [ModelCard(id=model_name) for model_name in SUPPORTED_MODELS.keys()]
    return ModelList(data=model_cards)
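
# Example against a hypothetical local deployment:
#   curl http://localhost:8000/v1/models
# returns {"object": "list", "data": [{"id": "llama3-8b-instruct", ...}, ...]}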


@app.post("/v1/chat/completions")
async def create_chat_completion(request: OpenAIChatCompletionRequest):
    """Creates a chat completion."""
    model_key = request.model
    if model_key not in SUPPORTED_MODELS:
        raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")

    replicate_model_id = SUPPORTED_MODELS[model_key]
    replicate_input = prepare_replicate_input(request)

    if request.stream:
        # Use the reliable polling-based streamer
        return EventSourceResponse(stream_replicate_with_polling(replicate_model_id, replicate_input))

    # Synchronous (non-streaming) request. Replicate caps the blocking
    # `Prefer: wait` header at 60 seconds.
    url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
    headers = {
        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
        "Content-Type": "application/json",
        "Prefer": "wait=60",
    }

    async with httpx.AsyncClient(timeout=150) as client:
        try:
            response = await client.post(url, headers=headers, json={"input": replicate_input})
            response.raise_for_status()
            prediction = response.json()

            output = prediction.get("output", "")
            if isinstance(output, list):
                output = "".join(output)

            # Basic tool call detection: the model is prompted to emit a JSON
            # object when it wants to call a tool (see format_tools_for_prompt)
            message_content, tool_calls = output, None
            try:
                tool_call_data = json.loads(output)
                if isinstance(tool_call_data, dict) and tool_call_data.get("type") == "tool_call":
                    message_content = None
                    tool_calls = [{
                        "id": f"call_{int(time.time())}",
                        "type": "function",
                        "function": {
                            "name": tool_call_data["name"],
                            "arguments": json.dumps(tool_call_data["arguments"]),
                        },
                    }]
            except (json.JSONDecodeError, TypeError):
                pass

            completion_response = {
                "id": prediction["id"],
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model_key,
                "choices": [{
                    "index": 0,
                    "message": {"role": "assistant", "content": message_content, "tool_calls": tool_calls},
                    "finish_reason": "tool_calls" if tool_calls else "stop",
                }],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            }
            return JSONResponse(content=completion_response)

        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=e.response.text)
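

# A minimal local runner for development. This is a sketch: it assumes
# `uvicorn` is installed and port 8000 is free (both are assumptions, not
# requirements of the API above).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example non-streaming request against such a local deployment (hypothetical):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama3-8b-instruct", "messages": [{"role": "user", "content": "Hi"}]}'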