rkihacker committed
Commit 580cccc · verified · 1 Parent(s): a9bb1ec

Update main.py

Files changed (1)
  1. main.py +15 -5
main.py CHANGED
@@ -5,7 +5,7 @@ import json
 import time
 import asyncio
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 from typing import List, Dict, Any, Optional, Union, Literal
 from dotenv import load_dotenv
@@ -17,7 +17,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")

 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.0 (Full OpenAI Compatibility)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.2.1 (Spacing Fixed)")

 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -106,6 +106,7 @@ def format_messages_for_replicate(messages: List[ChatMessage], functions: Option
             user_text_content = str(msg.content)
             prompt_parts.append(f"User: {user_text_content}")

+    # Fix: Don't add trailing space, let model decide spacing
     prompt_parts.append("Assistant:")
     return {
         "prompt": "\n\n".join(prompt_parts),
@@ -161,6 +162,7 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
         async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
             current_event = None
             accumulated_content = ""
+            first_token = True

             async for line in sse.aiter_lines():
                 if not line: continue
@@ -177,6 +179,11 @@ async def stream_replicate_response(replicate_model_id: str, input_payload: dict
                     except (json.JSONDecodeError, TypeError):
                         content_token = raw_data

+                    # Fix: Handle spacing properly - don't prepend space to first token
+                    if first_token:
+                        content_token = content_token.lstrip()
+                        first_token = False
+
                     accumulated_content += content_token
                     completion_tokens += 1

@@ -291,6 +298,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
         pred = resp.json()
         output = "".join(pred.get("output", []))

+        # Fix: Clean up leading/trailing whitespace
+        output = output.strip()
+
         # Calculate timing and tokens
         end_time = time.time()
         inference_time = end_time - start_time
@@ -337,7 +347,7 @@ async def create_chat_completion(request: ChatCompletionRequest):

 @app.get("/")
 async def root():
-    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.0"}
+    return {"message": "Replicate to OpenAI Compatibility Layer API", "version": "9.2.1"}

 # Performance optimization middleware
 @app.middleware("http")
@@ -346,5 +356,5 @@ async def add_performance_headers(request, call_next):
     response = await call_next(request)
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = str(round(process_time, 3))
-    response.headers["X-API-Version"] = "9.2.0"
-    return response
+    response.headers["X-API-Version"] = "9.2.1"
+    return response
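The prompt-side change is easiest to see on a toy conversation: with no trailing space after "Assistant:", the joined prompt ends on the bare role tag and the model supplies its own leading whitespace. A minimal sketch of that assembly, with an invented message in place of the real chat history:

prompt_parts = ["System: You are a helpful assistant.", "User: Hi there"]
prompt_parts.append("Assistant:")  # no trailing space; the model decides the spacing

prompt = "\n\n".join(prompt_parts)
print(repr(prompt[-10:]))  # 'Assistant:' with nothing after the colon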
 
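Because the model now emits its own leading whitespace, the streaming path strips it from the first SSE token only; later tokens keep their spacing so words are not glued together, while the non-streaming path takes the blunter route of output.strip() on the joined prediction. A small sketch of the first-token rule in isolation, using an invented token sequence in place of the real Replicate event stream:

def normalize_stream(tokens):
    # Mirrors the first_token / lstrip logic added to stream_replicate_response:
    # strip leading whitespace from the first token only, pass the rest through.
    first_token = True
    for content_token in tokens:
        if first_token:
            content_token = content_token.lstrip()
            first_token = False
        yield content_token

# Models often emit a leading space on the first token after "Assistant:".
print("".join(normalize_stream([" Hello", ",", " how", " can", " I", " help?"])))
# -> "Hello, how can I help?"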
 
 
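A quick way to confirm the 9.2.1 bump end to end is to hit the root route and read the headers set by the performance middleware. This assumes the app is served locally on port 8000 (for example via uvicorn main:app --port 8000); adjust the URL for your deployment:

import httpx

resp = httpx.get("http://localhost:8000/")
print(resp.json()["version"])          # expected: "9.2.1"
print(resp.headers["X-API-Version"])   # "9.2.1", set by add_performance_headers
print(resp.headers["X-Process-Time"])  # e.g. "0.001" (seconds, rounded to 3 places)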