File size: 16,148 Bytes
89b138d
 
 
 
a9bb1ec
5d3f475
9f14d65
 
580cccc
89b138d
 
 
 
415ec30
89b138d
 
b236837
9f14d65
89b138d
 
9f14d65
 
89b138d
415ec30
67e24f8
9f14d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89b138d
2fc646f
89b138d
3830a6b
 
 
 
 
89b138d
3830a6b
 
 
89b138d
3830a6b
67e24f8
3830a6b
 
 
a9bb1ec
3830a6b
 
 
 
a9bb1ec
3830a6b
 
 
a9bb1ec
3830a6b
 
 
a9bb1ec
3830a6b
 
 
 
a9bb1ec
3830a6b
 
 
 
 
 
 
 
 
 
 
a9bb1ec
 
3830a6b
 
 
 
a9bb1ec
3830a6b
 
 
 
 
a9bb1ec
3830a6b
 
 
 
 
 
a9bb1ec
 
3830a6b
 
 
 
a9bb1ec
3830a6b
 
 
 
a9bb1ec
3830a6b
 
 
 
 
 
2fc646f
415ec30
89b138d
415ec30
67e24f8
 
c466862
89b138d
 
415ec30
b236837
 
67e24f8
 
b236837
a9bb1ec
a135be4
 
 
 
a9bb1ec
67e24f8
a9bb1ec
67e24f8
 
 
 
a9bb1ec
 
a135be4
 
 
0f99721
a9bb1ec
 
 
 
 
 
 
 
a135be4
 
 
 
 
 
 
 
 
c466862
a135be4
 
415ec30
67e24f8
a9bb1ec
67e24f8
a9bb1ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b36c2
67e24f8
415ec30
1120bba
89b138d
2767573
a9bb1ec
2767573
 
a9bb1ec
89b138d
415ec30
89b138d
 
ea53c08
 
a9bb1ec
54de3fd
89b138d
54de3fd
0f99721
54de3fd
a9bb1ec
54de3fd
dafbe9c
e014ad9
c5a8085
e014ad9
67e24f8
a9bb1ec
e014ad9
a9bb1ec
 
e014ad9
 
a9bb1ec
 
 
0f99721
a9bb1ec
 
 
67e24f8
a9bb1ec
67e24f8
 
 
 
 
5e2bd86
67e24f8
 
5e2bd86
67e24f8
 
 
 
 
 
 
 
 
 
a9bb1ec
 
0f99721
67e24f8
 
a9bb1ec
 
e014ad9
a9bb1ec
e014ad9
 
a9bb1ec
7b0c05f
415ec30
9f14d65
89b138d
67e24f8
 
 
415ec30
89b138d
9f14d65
a9bb1ec
67e24f8
 
 
415ec30
e014ad9
89b138d
0f99721
a9bb1ec
b236837
a9bb1ec
 
67e24f8
 
a9bb1ec
b236837
 
 
 
0f99721
 
a9bb1ec
b236837
a9bb1ec
89b138d
a9bb1ec
67e24f8
a9bb1ec
 
415ec30
a9bb1ec
0f99721
a9bb1ec
2767573
 
415ec30
e014ad9
a9bb1ec
e014ad9
 
5d3f475
67e24f8
5d3f475
 
67e24f8
 
 
 
 
 
 
 
 
 
 
 
2767573
a9bb1ec
 
 
 
0f99721
 
 
a9bb1ec
 
67e24f8
0f99721
67e24f8
2767573
a9bb1ec
 
 
 
 
 
0f99721
 
a9bb1ec
 
 
 
 
0f99721
a9bb1ec
 
e014ad9
dafbe9c
a9bb1ec
 
 
 
 
67e24f8
 
 
 
a9bb1ec
 
 
 
 
 
 
67e24f8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import os
import httpx
import json
import time
import asyncio
import secrets
from fastapi import FastAPI, HTTPException, Security, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Union, Literal
from dotenv import load_dotenv

# Load environment variables (a local .env file, if present, is merged into os.environ first).
load_dotenv()
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")  # credential for the upstream Replicate API
SERVER_API_KEY = os.environ.get("SERVER_API_KEY")  # bearer key that clients of this server must present

# Refuse to start at import time when either secret is missing or empty.
if not REPLICATE_API_TOKEN:
    raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
if not SERVER_API_KEY:
    raise ValueError("SERVER_API_KEY environment variable not set. This is required to protect your server.")

# FastAPI application instance shared by every route and middleware below.
app = FastAPI(
    title="Replicate to OpenAI Compatibility Layer",
    version="9.2.8 (Raw Output Fix)",
)

# --- Authentication ---
# Bearer-token extractor consumed by the verify_api_key dependency.
security = HTTPBearer()

async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
    """
    Verify the API key provided in the Authorization header.

    Args:
        credentials: scheme and token extracted by the module-level ``HTTPBearer``.

    Returns:
        True when the scheme is Bearer and the token equals ``SERVER_API_KEY``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) otherwise.
    """
    # Auth schemes are case-insensitive (RFC 7235) and FastAPI's HTTPBearer already
    # accepts e.g. "bearer"; a case-sensitive check would reject those clients.
    scheme_ok = credentials.scheme.lower() == "bearer"
    # Constant-time comparison so response timing does not reveal how much of the
    # key matched. Compare bytes: compare_digest raises TypeError on non-ASCII str.
    key_ok = secrets.compare_digest(
        credentials.credentials.encode("utf-8"), SERVER_API_KEY.encode("utf-8")
    )
    if not (scheme_ok and key_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True

# --- Pydantic Models ---
class ModelCard(BaseModel):
    """One entry in the ``/v1/models`` listing (OpenAI model-object shape)."""
    id: str  # client-facing model alias
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix seconds at instantiation
    owned_by: str = "replicate"

class ModelList(BaseModel):
    """Response body of ``GET /v1/models``: an OpenAI-style list of model cards."""
    object: str = "list"
    data: List[ModelCard] = []  # empty unless the caller supplies cards

class ChatMessage(BaseModel):
    """A single chat turn in OpenAI format, used for both requests and responses.

    ``content`` is optional: OpenAI-style assistant turns that carry
    ``tool_calls`` have ``content: null``, both when clients replay history and
    when this server returns a tool call. A required field rejected the former
    with 422 and crashed the latter with a validation error (HTTP 500).
    """
    role: Literal["system", "user", "assistant", "tool"]
    # Plain text, a list of OpenAI content parts (e.g. {"type": "text", ...} or
    # {"type": "image_url", ...}), or None when the turn only carries tool calls.
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    # Kept loosely typed: raw dicts from client JSON, ToolCall objects on responses.
    tool_calls: Optional[List[Any]] = None

class FunctionDefinition(BaseModel):
    """A callable function advertised to the model (OpenAI function-calling shape)."""
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None  # presumably a JSON Schema object, as in OpenAI's API; not validated here

class ToolDefinition(BaseModel):
    """An entry of the request's ``tools`` list; only ``type="function"`` is accepted."""
    type: Literal["function"]
    function: FunctionDefinition

class FunctionCall(BaseModel):
    """Name and arguments of a function the model chose to call."""
    name: str
    arguments: str  # kept as a string; OpenAI clients expect JSON text here

class ToolCall(BaseModel):
    """A tool call returned to the client in a message or a stream delta."""
    id: str  # built as "call_..." by the code in this module
    type: Literal["function"] = "function"
    function: FunctionCall

class ChatCompletionRequest(BaseModel):
    """Request body of ``POST /v1/chat/completions`` (a subset of the OpenAI schema)."""
    model: str  # client-facing alias; must be a key of SUPPORTED_MODELS
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None  # None: not forwarded, so the upstream model's default applies
    stream: Optional[bool] = False  # True: respond with Server-Sent Events
    stop: Optional[Union[str, List[str]]] = None  # NOTE(review): accepted but not forwarded to Replicate
    tools: Optional[List[ToolDefinition]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None  # NOTE(review): accepted for compatibility, not read
    functions: Optional[List[FunctionDefinition]] = None  # legacy OpenAI function-calling list
    function_call: Optional[Union[str, Dict[str, str]]] = None  # NOTE(review): accepted for compatibility, not read

class Choice(BaseModel):
    """One completion alternative in a non-streaming response."""
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None  # e.g. "stop" or "tool_calls"

class Usage(BaseModel):
    """Token accounting attached to completions.

    The counts produced in this module are estimates (character length // 4,
    or one per streamed output event), not tokenizer-exact values.
    """
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    inference_time: Optional[float] = None  # non-standard extension: elapsed seconds, rounded to 3 decimals

class ChatCompletion(BaseModel):
    """Non-streaming ``/v1/chat/completions`` response (OpenAI ``chat.completion`` object)."""
    id: str
    object: str = "chat.completion"
    created: int  # Unix seconds
    model: str
    choices: List[Choice]
    usage: Usage

class DeltaMessage(BaseModel):
    """Incremental message fragment carried by a streaming chunk; unset parts stay None."""
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None

class ChoiceDelta(BaseModel):
    """One choice inside a streaming chunk."""
    index: int
    delta: DeltaMessage
    finish_reason: Optional[str] = None  # left None on intermediate chunks; set on the closing chunk

class ChatCompletionChunk(BaseModel):
    """One Server-Sent-Events payload of a streaming response (``chat.completion.chunk``)."""
    id: str
    object: str = "chat.completion.chunk"
    created: int  # Unix seconds
    model: str
    choices: List[ChoiceDelta]
    usage: Optional[Usage] = None  # populated only on the closing chunk

# --- Supported Models ---
SUPPORTED_MODELS = {
    # Client-facing alias -> Replicate model reference. Insertion order is the
    # order in which /v1/models lists the aliases.
    "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
    # Pinned to one version ("owner/name:version_hash") rather than "owner/name".
    "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358"
}

# --- Core Logic ---

def generate_request_id() -> str:
    """Build a completion id of the form ``gen-<unix seconds>-<16 hex chars>``.

    The random suffix comes from ``secrets`` so ids minted in the same second
    are still distinct.
    """
    timestamp = int(time.time())
    random_suffix = secrets.token_hex(8)
    return "gen-{}-{}".format(timestamp, random_suffix)

def _message_content_text(content: Any) -> str:
    """Return the text carried by a message ``content`` value.

    ``None`` becomes "", a list of OpenAI content parts contributes only its
    ``{"type": "text"}`` parts (concatenated as-is), anything else is ``str()``-ed.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        return "".join(
            str(part.get("text") or "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def _render_tool_call(tool_call: Any) -> str:
    """Render one assistant tool call as ``name(arguments)``.

    Accepts the raw dict a client sends (``{"function": {"name", "arguments"}}``)
    as well as an object exposing ``.function.name`` / ``.function.arguments``.
    """
    if isinstance(tool_call, dict):
        function = tool_call.get("function")
        if not isinstance(function, dict):
            function = {}
        name = function.get("name", "")
        arguments = function.get("arguments", "")
    else:
        function = getattr(tool_call, "function", None)
        name = getattr(function, "name", "")
        arguments = getattr(function, "arguments", "")
    return f"{name}({arguments})"


def format_messages_for_replicate(messages: List[ChatMessage], functions: Optional[List[FunctionDefinition]] = None) -> Dict[str, Any]:
    """Flatten OpenAI-style chat messages into a single Replicate prompt.

    Args:
        messages: the conversation, in order.
        functions: optional function definitions, advertised to the model in a
            text preamble at the start of the prompt.

    Returns:
        A dict with:
        - "prompt": "User: ...", "Assistant: ...", "Tool Response: ..." turns
          joined by blank lines and ending with an "Assistant:" cue;
        - "system_prompt": text of all system messages joined by blank lines,
          or None when there are none;
        - "image": URL of the last image part found in a user message, or None.
    """
    prompt_parts: List[str] = []
    system_texts: List[str] = []
    image_input = None

    if functions:
        preamble = ["You have access to the following tools. Use them if required to answer the user's question.\n\n"]
        for func in functions:
            preamble.append(f"- Function: {func.name}\n")
            if func.description:
                preamble.append(f"  Description: {func.description}\n")
            if func.parameters:
                preamble.append(f"  Parameters: {json.dumps(func.parameters)}\n")
        prompt_parts.append("".join(preamble))

    for msg in messages:
        if msg.role == "system":
            # Keep every system message instead of silently keeping only the last.
            system_texts.append(_message_content_text(msg.content))
        elif msg.role == "assistant":
            text = _message_content_text(msg.content)
            if msg.tool_calls:
                # Client-supplied tool_calls are plain dicts, so no attribute access here.
                calls = "".join(f"- {_render_tool_call(tc)}\n" for tc in msg.tool_calls)
                prompt_parts.append(f"Assistant: {text}\nTool calls:\n{calls}")
            else:
                prompt_parts.append(f"Assistant: {text}")
        elif msg.role == "tool":
            prompt_parts.append(f"Tool Response: {_message_content_text(msg.content)}")
        elif msg.role == "user":
            if isinstance(msg.content, list):
                for part in msg.content:
                    if not (isinstance(part, dict) and part.get("type") == "image_url"):
                        continue
                    image = part.get("image_url")
                    url = image if isinstance(image, str) else (image.get("url") if isinstance(image, dict) else None)
                    # Only one image is forwarded; a part without a URL must not erase an earlier one.
                    if url:
                        image_input = url
            prompt_parts.append(f"User: {_message_content_text(msg.content)}")

    prompt_parts.append("Assistant:")  # the model continues from this cue
    return {
        "prompt": "\n\n".join(prompt_parts),
        "system_prompt": "\n\n".join(system_texts) if system_texts else None,
        "image": image_input,
    }

def parse_function_call(content: str) -> Optional[Dict[str, Any]]:
    """Extract a function call the model wrote as a JSON object in its text output.

    Recognised shapes, anywhere in ``content`` (surrounding prose is allowed)::

        {"name": "...", "arguments": ...}
        {"function_call": {"name": "...", "arguments": ...}}

    Args:
        content: raw model output.

    Returns:
        ``{"name": str, "arguments": str}`` for the first top-level JSON object
        that matches. Non-string ``arguments`` are re-serialised with
        ``json.dumps`` because ``FunctionCall.arguments`` (and OpenAI clients)
        expect a string. ``None`` when nothing matches.
    """
    # Cheap pre-check so ordinary prose never pays for the JSON scan.
    if not content or not ("function_call" in content or ("name" in content and "arguments" in content)):
        return None

    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            candidate, consumed = decoder.raw_decode(content[start:])
        except json.JSONDecodeError:
            # Not a JSON object here (e.g. prose braces) -- try the next "{".
            start = content.find("{", start + 1)
            continue
        if isinstance(candidate, dict):
            wrapped = candidate.get("function_call")
            call = wrapped if isinstance(wrapped, dict) else candidate
            name = call.get("name")
            if isinstance(name, str) and "arguments" in call:
                arguments = call["arguments"]
                if not isinstance(arguments, str):
                    arguments = json.dumps(arguments)
                return {"name": name, "arguments": arguments}
        # Skip past the whole decoded value so nested objects are not mistaken for calls.
        start = content.find("{", start + consumed)
    return None

async def _replicate_sse_events(lines):
    """Group raw Server-Sent-Events lines into ``(event, data)`` pairs.

    Follows the SSE framing rules:
    - an event ends at a blank line;
    - several ``data:`` lines in one event are joined with a newline (a
      streamed token that contains a newline arrives that way);
    - only the single optional space after a field's colon is removed, so
      leading spaces inside tokens survive.

    Events without any ``data:`` line are not emitted. ``id:``/``retry:``
    fields and ``:`` comment lines are ignored. An event still open when the
    input ends is flushed (lenient).

    Args:
        lines: async iterable of text lines with line terminators removed.

    Yields:
        ``(event_name, data)`` tuples; ``event_name`` defaults to "message".
    """
    event_name = "message"
    data_lines: List[str] = []
    async for line in lines:
        if not line:
            if data_lines:
                yield event_name, "\n".join(data_lines)
            event_name, data_lines = "message", []
            continue
        if line.startswith(":"):
            continue  # comment / keep-alive
        field, _, value = line.partition(":")
        if value.startswith(" "):
            value = value[1:]
        if field == "event":
            event_name = value
        elif field == "data":
            data_lines.append(value)
    if data_lines:
        yield event_name, "\n".join(data_lines)


async def stream_replicate_response(replicate_model_id: str, input_payload: dict, request_id: str):
    """
    Run a streaming Replicate prediction and relay it as OpenAI chat-completion SSE.

    Args:
        replicate_model_id: "owner/name" or "owner/name:version".
        input_payload: model input; the length of its "prompt" feeds the rough
            prompt-token estimate.
        request_id: id stamped on every emitted chunk.

    Yields:
        ``data: <json>`` SSE strings, in this order:
        - content deltas (the first delta carries ``role="assistant"``);
        - at most one tool-call delta, once the accumulated output parses as a
          function call (later output is then not relayed);
        - a closing chunk with ``finish_reason`` ("tool_calls" or "stop") and
          estimated usage;
        - ``data: [DONE]``.
        Any upstream failure yields a single ``data: {"error": ...}`` event and
        ends the stream without ``[DONE]``.
    """
    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
    if ":" in replicate_model_id:
        # Versioned ids are not valid in the /models/{owner}/{name} path; they go
        # through the generic predictions endpoint with an explicit version.
        url = "https://api.replicate.com/v1/predictions"
        body = {"version": replicate_model_id.split(":", 1)[1], "input": input_payload, "stream": True}
    else:
        url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
        body = {"input": input_payload, "stream": True}

    def error_event(message: str, error_type: str) -> str:
        """Format an OpenAI-style error object as one SSE data line."""
        return f"data: {json.dumps({'error': {'message': message, 'type': error_type}})}\n\n"

    def chunk_event(delta: DeltaMessage, finish_reason: Optional[str] = None, usage: Optional[Usage] = None) -> str:
        """Wrap a delta in a ChatCompletionChunk and format it as one SSE data line."""
        chunk = ChatCompletionChunk(
            id=request_id,
            created=int(time.time()),
            model=replicate_model_id,
            choices=[ChoiceDelta(index=0, delta=delta, finish_reason=finish_reason)],
            usage=usage,
        )
        return f"data: {chunk.json()}\n\n"

    start_time = time.perf_counter()
    prompt_tokens = len(input_payload.get("prompt", "")) // 4  # rough ~4 characters per token
    completion_tokens = 0  # counts output events, not true tokens
    accumulated_content = ""
    role_sent = False
    tool_call_sent = False

    async with httpx.AsyncClient(timeout=300.0) as client:
        try:
            response = await client.post(url, headers=headers, json=body)
            response.raise_for_status()
            prediction = response.json()
        except httpx.HTTPStatusError as e:
            error_details = e.response.text
            try:
                parsed = e.response.json()
                if isinstance(parsed, dict):
                    error_details = parsed.get("detail", error_details)
            except ValueError:
                pass
            yield error_event(f"Upstream Error: {error_details}", "replicate_error")
            return
        except (httpx.HTTPError, ValueError) as e:
            # Connection/timeout failures, or a non-JSON success body.
            yield error_event(f"Upstream Error: {e}", "replicate_error")
            return

        stream_url = (prediction.get("urls") or {}).get("stream") if isinstance(prediction, dict) else None
        if not stream_url:
            yield error_event("Model did not return a stream URL.", "replicate_error")
            return

        try:
            async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as event_stream:
                event_stream.raise_for_status()
                async for event, data in _replicate_sse_events(event_stream.aiter_lines()):
                    if event == "output":
                        # Text output arrives as plain text and is relayed verbatim: no strip
                        # (it glued tokens together, e.g. "HowcanI") and no json.loads (it
                        # turned numeric tokens into ints and broke concatenation).
                        if not data:
                            continue
                        completion_tokens += 1
                        if tool_call_sent:
                            continue  # the reply was already surfaced as a tool call
                        accumulated_content += data
                        delta_role = None if role_sent else "assistant"
                        function_call = parse_function_call(accumulated_content)
                        if function_call:
                            arguments = function_call["arguments"]
                            if not isinstance(arguments, str):
                                arguments = json.dumps(arguments)
                            tool_call = ToolCall(
                                id=f"call_{secrets.token_hex(12)}",
                                function=FunctionCall(name=function_call["name"], arguments=arguments),
                            )
                            yield chunk_event(DeltaMessage(role=delta_role, tool_calls=[tool_call]))
                            tool_call_sent = True
                        else:
                            yield chunk_event(DeltaMessage(role=delta_role, content=data))
                        role_sent = True
                    elif event == "error":
                        yield error_event(f"Upstream Error: {data}", "replicate_error")
                        return
                    elif event == "done":
                        break
        except httpx.ReadTimeout:
            yield error_event("Stream timed out.", "timeout_error")
            return
        except httpx.HTTPStatusError as e:
            yield error_event(f"Upstream Error: stream returned HTTP {e.response.status_code}", "replicate_error")
            return
        except httpx.HTTPError as e:
            yield error_event(f"Upstream Error: {e}", "replicate_error")
            return

    # Close the stream even if Replicate disconnected without a "done" event.
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        inference_time=round(time.perf_counter() - start_time, 3),
    )
    yield chunk_event(DeltaMessage(), "tool_calls" if tool_call_sent else "stop", usage)
    yield "data: [DONE]\n\n"

# --- Endpoints ---
@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """
    List every model alias this server accepts (requires a valid API key).

    Returns:
        ModelList with one ModelCard per SUPPORTED_MODELS key, in dict order.
    """
    cards = [ModelCard(id=alias) for alias in SUPPORTED_MODELS]
    return ModelList(data=cards)

@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
async def create_chat_completion(request: ChatCompletionRequest):
    """
    Protected endpoint to create a chat completion.

    Translates the OpenAI-style request into a Replicate prediction:
    - with ``stream=True`` the result is relayed as SSE by
      ``stream_replicate_response``;
    - otherwise the prediction is awaited (``Prefer: wait`` plus polling until
      it reaches a terminal status) and returned as one ``ChatCompletion``.

    Raises:
        HTTPException:
            - 404 for an unknown model;
            - Replicate's status code for upstream HTTP errors;
            - 502 when Replicate is unreachable, returns nothing to poll, or the
              prediction does not succeed;
            - 504 when the prediction does not finish in time;
            - 500 for anything unexpected.
    """
    if request.model not in SUPPORTED_MODELS:
        raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}")

    replicate_model_id = SUPPORTED_MODELS[request.model]

    # Both the legacy `functions` list and the newer `tools` list describe callable
    # functions; advertise all of them (previously `tools` was silently ignored).
    function_defs = list(request.functions or [])
    function_defs.extend(tool.function for tool in (request.tools or []))
    formatted = format_messages_for_replicate(request.messages, function_defs or None)

    replicate_input = {
        "prompt": formatted["prompt"],
        # `is not None` so an explicit 0 (e.g. temperature=0 for greedy decoding)
        # is honored instead of being replaced by the default via `or`.
        "temperature": request.temperature if request.temperature is not None else 0.7,
        "top_p": request.top_p if request.top_p is not None else 1.0,
    }

    if request.max_tokens is not None:
        replicate_input["max_new_tokens"] = request.max_tokens

    if formatted["system_prompt"]:
        replicate_input["system_prompt"] = formatted["system_prompt"]
    if formatted["image"]:
        replicate_input["image"] = formatted["image"]

    request_id = generate_request_id()

    if request.stream:
        return StreamingResponse(
            stream_replicate_response(replicate_model_id, replicate_input, request_id),
            media_type="text/event-stream"
        )

    # Non-streaming response.
    # Versioned ids ("owner/name:version") are not valid in the /models/... path;
    # they go through the generic predictions endpoint with an explicit version.
    if ":" in replicate_model_id:
        url = "https://api.replicate.com/v1/predictions"
        body = {"version": replicate_model_id.split(":", 1)[1], "input": replicate_input}
    else:
        url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
        body = {"input": replicate_input}
    auth_header = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}
    # Without "Prefer: wait" Replicate answers immediately with a still-running
    # prediction whose output is null; anything not finished after the wait is polled.
    headers = {**auth_header, "Content-Type": "application/json", "Prefer": "wait"}
    terminal_statuses = {"succeeded", "failed", "canceled", "aborted"}
    poll_interval = 0.5  # seconds between status polls
    deadline_seconds = 300.0
    start_time = time.perf_counter()

    async with httpx.AsyncClient(timeout=deadline_seconds) as client:
        try:
            resp = await client.post(url, headers=headers, json=body)
            resp.raise_for_status()
            pred = resp.json()

            while pred.get("status") not in terminal_statuses:
                poll_url = (pred.get("urls") or {}).get("get")
                if not poll_url:
                    raise HTTPException(status_code=502, detail="Replicate did not return a URL to poll the prediction.")
                if time.perf_counter() - start_time > deadline_seconds:
                    raise HTTPException(status_code=504, detail="Timed out waiting for the Replicate prediction to finish.")
                await asyncio.sleep(poll_interval)
                poll_resp = await client.get(poll_url, headers=auth_header)
                poll_resp.raise_for_status()
                pred = poll_resp.json()

            if pred.get("status") != "succeeded":
                raise HTTPException(status_code=502, detail=f"Replicate prediction {pred.get('status')}: {pred.get('error')}")

            # 'output' may be a list of text chunks (language models), a string, or null.
            raw_output = pred.get("output")
            if isinstance(raw_output, list):
                output = "".join(str(part) for part in raw_output if part is not None)
            elif isinstance(raw_output, str):
                output = raw_output
            else:
                output = ""
            # Returned verbatim (no strip) so a whitespace-only reply is preserved.

            elapsed = time.perf_counter() - start_time
            prompt_tokens = len(replicate_input["prompt"]) // 4  # rough ~4 characters per token
            completion_tokens = len(output) // 4

            tool_calls = None
            finish_reason = "stop"
            message_content = output

            function_call = parse_function_call(output)
            if function_call:
                arguments = function_call["arguments"]
                if not isinstance(arguments, str):
                    arguments = json.dumps(arguments)  # FunctionCall.arguments must be a string
                tool_calls = [ToolCall(
                    id=f"call_{secrets.token_hex(12)}",  # unique even for calls within the same second
                    function=FunctionCall(name=function_call["name"], arguments=arguments),
                )]
                finish_reason = "tool_calls"
                message_content = None  # OpenAI convention: content is null when tool_calls are present

            return ChatCompletion(
                id=request_id,
                created=int(time.time()),
                model=request.model,
                choices=[Choice(
                    index=0,
                    message=ChatMessage(role="assistant", content=message_content, tool_calls=tool_calls),
                    finish_reason=finish_reason
                )],
                usage=Usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                    inference_time=round(elapsed, 3)
                )
            )
        except HTTPException:
            raise  # already carries the intended status; do not rewrap as 500
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}") from e
        except httpx.HTTPError as e:
            raise HTTPException(status_code=502, detail=f"Error contacting Replicate API: {e}") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}") from e

@app.get("/")
async def root():
    """
    Unauthenticated health-check endpoint reporting the service name and version.
    """
    payload = {"message": "Replicate to OpenAI Compatibility Layer API"}
    payload["version"] = "9.2.8"
    return payload

@app.middleware("http")
async def add_performance_headers(request, call_next):
    """
    Stamp every response with ``X-Process-Time`` and ``X-API-Version`` headers.

    ``X-Process-Time`` is the elapsed time in seconds, rounded to 3 decimals,
    spent awaiting ``call_next``. NOTE(review): for streamed responses this
    presumably covers only the time until the response object is returned,
    not the time spent streaming the body.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the wall
    # clock and can jump (NTP/DST), giving skewed or negative durations.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(round(process_time, 3))
    response.headers["X-API-Version"] = "9.2.8"
    return response