import os
import json
import time
import asyncio
import secrets
from typing import List, Dict, Any, Optional, Union, Literal

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Security, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
|
load_dotenv()

REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
SERVER_API_KEY = os.getenv("SERVER_API_KEY")

if not REPLICATE_API_TOKEN:
    raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
if not SERVER_API_KEY:
    raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")
|
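# Example .env (hypothetical values, shown for illustration only):
#   REPLICATE_API_TOKEN=r8_xxxxxxxxxxxxxxxx
#   SERVER_API_KEY=choose-a-long-random-secret
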
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.8 (Raw Output Fix)")

security = HTTPBearer()
|
async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Verify the API key provided in the Authorization header."""
    # Compare in constant time so the key cannot be probed via response timing.
    if credentials.scheme != "Bearer" or not secrets.compare_digest(credentials.credentials, SERVER_API_KEY):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
|
class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "replicate"


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    # Content may be omitted on assistant messages that only carry tool calls.
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Any]] = None
|
class FunctionDefinition(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ToolDefinition(BaseModel):
    type: Literal["function"]
    function: FunctionDefinition


class FunctionCall(BaseModel):
    name: str
    arguments: str


class ToolCall(BaseModel):
    id: str
    type: Literal["function"] = "function"
    function: FunctionCall
|
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    tools: Optional[List[ToolDefinition]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    functions: Optional[List[FunctionDefinition]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None
|
|
class Choice(BaseModel): |
|
|
index: int |
|
|
message: ChatMessage |
|
|
finish_reason: Optional[str] = None |
|
|
|
|
|
class Usage(BaseModel): |
|
|
prompt_tokens: int |
|
|
completion_tokens: int |
|
|
total_tokens: int |
|
|
inference_time: Optional[float] = None |
|
|
|
|
|
class ChatCompletion(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Usage
|
class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None


class ChoiceDelta(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChoiceDelta]
    usage: Optional[Usage] = None
|
SUPPORTED_MODELS = {
    "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
    "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
}
|
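# Note: the llava-13b entry pins a specific version hash. The
# models/{owner}/{name}/predictions endpoint used below expects an un-pinned
# model id, so version-pinned entries may instead need Replicate's generic
# /v1/predictions endpoint with a "version" field (not handled here).
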
def generate_request_id() -> str:
    """Generates a unique request ID in the user-specified format."""
    return f"gen-{int(time.time())}-{secrets.token_hex(8)}"
|
def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
    """Flatten an OpenAI-style message list into a single Replicate prompt."""
    prompt_parts = []
    system_prompt = None
    image_input = None

    # Describe any available functions at the top of the prompt so the model
    # knows it can call them.
    if functions:
        functions_text = "You have access to the following tools. Use them if required to answer the user's question.\n\n"
        for func in functions:
            functions_text += f"- Function: {func.name}\n"
            if func.description:
                functions_text += f" Description: {func.description}\n"
            if func.parameters:
                functions_text += f" Parameters: {json.dumps(func.parameters)}\n"
        prompt_parts.append(functions_text)

    for msg in messages:
        if msg.role == "system":
            system_prompt = str(msg.content)
        elif msg.role == "assistant":
            if msg.tool_calls:
                tool_calls_text = "\nTool calls:\n"
                for tool_call in msg.tool_calls:
                    tool_calls_text += f"- {tool_call.function.name}({tool_call.function.arguments})\n"
                prompt_parts.append(f"Assistant: {tool_calls_text}")
            else:
                prompt_parts.append(f"Assistant: {msg.content}")
        elif msg.role == "tool":
            prompt_parts.append(f"Tool Response: {msg.content}")
        elif msg.role == "user":
            user_text_content = ""
            if isinstance(msg.content, list):
                # Multimodal content: collect text parts and the last image URL.
                for item in msg.content:
                    if item.get("type") == "text":
                        user_text_content += item.get("text", "")
                    elif item.get("type") == "image_url":
                        image_url_data = item.get("image_url", {})
                        image_input = image_url_data.get("url")
            else:
                user_text_content = str(msg.content)
            prompt_parts.append(f"User: {user_text_content}")

    prompt_parts.append("Assistant:")
    return {
        "prompt": "\n\n".join(prompt_parts),
        "system_prompt": system_prompt,
        "image": image_input
    }
|
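# For illustration: messages [system "Be brief.", user "Hi"] would yield
#   {"prompt": "User: Hi\n\nAssistant:", "system_prompt": "Be brief.", "image": None}
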
def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON function call from free-form model output."""
    try:
        if "function_call" in content or ("name" in content and "arguments" in content):
            start = content.find("{")
            end = content.rfind("}") + 1
            if start != -1 and end > start:
                parsed = json.loads(content[start:end])
                if "name" in parsed and "arguments" in parsed:
                    # OpenAI clients expect arguments as a JSON-encoded string.
                    if not isinstance(parsed["arguments"], str):
                        parsed["arguments"] = json.dumps(parsed["arguments"])
                    return parsed
    except (json.JSONDecodeError, TypeError):
        pass
    return None
|
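# Example (hypothetical model output):
#   parse_function_call('{"name": "get_weather", "arguments": {"city": "Paris"}}')
#   returns {"name": "get_weather", "arguments": '{"city": "Paris"}'}
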
async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str):
    url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}

    start_time = time.time()
    # Rough token estimate: ~4 characters per token.
    prompt_tokens = len(input_payload.get("prompt", "")) // 4
    completion_tokens = 0

    async with httpx.AsyncClient(timeout=300.0) as client:
        try:
            response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
            response.raise_for_status()
            prediction = response.json()
            stream_url = prediction.get("urls", {}).get("stream")
            if not stream_url:
                yield f"data: {json.dumps({'error': {'message': 'Model did not return a stream URL.'}})}\n\n"
                return
        except httpx.HTTPStatusError as e:
            error_details = e.response.text
            try:
                error_details = e.response.json().get("detail", error_details)
            except json.JSONDecodeError:
                pass
            yield f"data: {json.dumps({'error': {'message': f'Upstream Error: {error_details}', 'type': 'replicate_error'}})}\n\n"
            return
|
        try:
            async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                current_event = None
                accumulated_content = ""
                tool_call_emitted = False

                async for line in sse.aiter_lines():
                    if not line:
                        continue

                    if line.startswith("event:"):
                        current_event = line[len("event:"):].strip()
                    elif line.startswith("data:") and current_event == "output":
                        raw_data = line[5:].strip()
                        if not raw_data:
                            continue

                        # Tokens may arrive JSON-encoded or as raw text.
                        try:
                            content_token = json.loads(raw_data)
                        except (json.JSONDecodeError, TypeError):
                            content_token = raw_data
                        if not isinstance(content_token, str):
                            content_token = str(content_token)

                        accumulated_content += content_token
                        completion_tokens += 1

                        # If the accumulated text parses as a function call, emit
                        # the tool call once and stop forwarding raw text.
                        function_call = parse_function_call(accumulated_content)
                        if function_call:
                            if not tool_call_emitted:
                                tool_call_emitted = True
                                tool_call = ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))
                                chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(tool_calls=[tool_call]), finish_reason=None)])
                                yield f"data: {chunk.json()}\n\n"
                        elif content_token:
                            chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(content=content_token), finish_reason=None)])
                            yield f"data: {chunk.json()}\n\n"
                    elif current_event == "done":
                        end_time = time.time()
                        usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, inference_time=round(end_time - start_time, 3))
                        finish_reason = "tool_calls" if tool_call_emitted else "stop"
                        usage_chunk = ChatCompletionChunk(id=request_id, created=int(time.time()), model=replicate_model_id, choices=[ChoiceDelta(index=0, delta=DeltaMessage(), finish_reason=finish_reason)], usage=usage)
                        yield f"data: {usage_chunk.json()}\n\n"
                        break
        except httpx.ReadTimeout:
            yield f"data: {json.dumps({'error': {'message': 'Stream timed out.', 'type': 'timeout_error'}})}\n\n"
            return

        yield "data: [DONE]\n\n"
|
@app.get("/v1/models", dependencies=[Depends(verify_api_key)]) |
|
|
async def list_models(): |
|
|
""" |
|
|
Protected endpoint to list available models. |
|
|
""" |
|
|
return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()]) |
|
|
|
|
|
@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)]) |
|
|
async def create_chat_completion(request: ChatCompletionRequest): |
|
|
""" |
|
|
Protected endpoint to create a chat completion. |
|
|
""" |
|
|
if request.model not in SUPPORTED_MODELS: |
|
|
raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}") |
|
|
|
|
|
replicate_model_id = SUPPORTED_MODELS[request.model] |
|
|
formatted = format_messages_for_replicate(request.messages, request.functions) |
|
|
|
|
|
replicate_input = { |
|
|
"prompt": formatted["prompt"], |
|
|
"temperature": request.temperature or 0.7, |
|
|
"top_p": request.top_p or 1.0 |
|
|
} |
|
|
|
|
|
if request.max_tokens is not None: |
|
|
replicate_input["max_new_tokens"] = request.max_tokens |
|
|
|
|
|
if formatted["system_prompt"]: replicate_input["system_prompt"] = formatted["system_prompt"] |
|
|
if formatted["image"]: replicate_input["image"] = formatted["image"] |
|
|
|
|
|
request_id = generate_request_id() |
|
|
|
|
|
if request.stream: |
|
|
return StreamingResponse( |
|
|
stream_replicate_response(replicate_model_id, replicate_input, request_id), |
|
|
media_type="text/event-stream" |
|
|
) |
|
|
|
|
|
|
|
|
url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions" |
|
|
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"} |
|
|
start_time = time.time() |
|
|
|
|
|
async with httpx.AsyncClient() as client: |
|
|
try: |
|
|
resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=300.0) |
|
|
resp.raise_for_status() |
|
|
pred = resp.json() |
|
|
|
|
|
|
|
|
raw_output = pred.get("output") |
|
|
|
|
|
if isinstance(raw_output, list): |
|
|
output = "".join(raw_output) |
|
|
elif isinstance(raw_output, str): |
|
|
output = raw_output |
|
|
else: |
|
|
output = "" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            end_time = time.time()
            # Rough token estimates: ~4 characters per token.
            prompt_tokens = len(replicate_input.get("prompt", "")) // 4
            completion_tokens = len(output) // 4

            tool_calls = None
            finish_reason = "stop"
            message_content = output

            # If the model emitted a function call, surface it as a tool call
            # instead of plain content.
            function_call = parse_function_call(output)
            if function_call:
                tool_calls = [ToolCall(id=f"call_{int(time.time())}", function=FunctionCall(name=function_call["name"], arguments=function_call["arguments"]))]
                finish_reason = "tool_calls"
                message_content = None

            return ChatCompletion(
                id=request_id,
                created=int(time.time()),
                model=request.model,
                choices=[Choice(
                    index=0,
                    message=ChatMessage(role="assistant", content=message_content, tool_calls=tool_calls),
                    finish_reason=finish_reason
                )],
                usage=Usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                    inference_time=round(end_time - start_time, 3)
                )
            )
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
        except HTTPException:
            # Re-raise HTTP errors (e.g., the failed-prediction case) untouched.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
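# Example request (assumes the server is running locally on port 8000):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer $SERVER_API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama3-8b-instruct", "messages": [{"role": "user", "content": "Hi"}]}'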
|
@app.get("/") |
|
|
async def root(): |
|
|
""" |
|
|
Root endpoint for health checks. Does not require authentication. |
|
|
""" |
|
|
return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.8"} |
|
|
|
|
|
@app.middleware("http")
async def add_performance_headers(request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(round(process_time, 3))
    response.headers["X-API-Version"] = "9.2.8"
    return response
|
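# Local development entry point. Added for convenience; assumes uvicorn is
# installed (pip install uvicorn). Adjust host/port as needed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)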