Update main.py
Browse files
main.py
CHANGED
|
@@ -7,7 +7,6 @@ import random
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
from contextlib import asynccontextmanager
|
| 10 |
-
import json
|
| 11 |
|
| 12 |
# --- Production-Ready Configuration ---
|
| 13 |
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
@@ -17,7 +16,7 @@ logging.basicConfig(
|
|
| 17 |
)
|
| 18 |
|
| 19 |
TARGET_URL = os.getenv("TARGET_URL", "https://api.gmi-serving.com")
|
| 20 |
-
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "
|
| 21 |
DEFAULT_RETRY_CODES = "429,500,502,503,504"
|
| 22 |
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
|
| 23 |
try:
|
|
@@ -27,56 +26,11 @@ except ValueError:
|
|
| 27 |
logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
|
| 28 |
RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
|
| 29 |
|
| 30 |
-
# --- Helper
|
| 31 |
-
|
| 32 |
def generate_random_ip():
|
| 33 |
"""Generates a random, valid-looking IPv4 address."""
|
| 34 |
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
| 35 |
|
| 36 |
-
async def modified_aiter_raw(original_aiter):
|
| 37 |
-
"""
|
| 38 |
-
An async generator that intercepts and modifies the streaming data chunks.
|
| 39 |
-
It adds a prefix to the 'id' and includes a 'provider' field.
|
| 40 |
-
"""
|
| 41 |
-
buffer = ""
|
| 42 |
-
async for chunk in original_aiter:
|
| 43 |
-
buffer += chunk.decode('utf-8')
|
| 44 |
-
while '\n' in buffer:
|
| 45 |
-
line, buffer = buffer.split('\n', 1)
|
| 46 |
-
if line.startswith('data:'):
|
| 47 |
-
try:
|
| 48 |
-
# Strip the "data: " prefix to get the JSON string
|
| 49 |
-
json_str = line[len('data: '):].strip()
|
| 50 |
-
|
| 51 |
-
# Process only if it's not the SSE termination message
|
| 52 |
-
if json_str and json_str != '[DONE]':
|
| 53 |
-
data = json.loads(json_str)
|
| 54 |
-
|
| 55 |
-
# Add 'NAI-' prefix to the id
|
| 56 |
-
if 'id' in data:
|
| 57 |
-
data['id'] = f"NAI-{data['id']}"
|
| 58 |
-
|
| 59 |
-
# Add the provider field
|
| 60 |
-
data['provider'] = 'TypeGPT'
|
| 61 |
-
|
| 62 |
-
# Reconstruct the SSE data line
|
| 63 |
-
modified_line = f"data: {json.dumps(data)}"
|
| 64 |
-
yield (modified_line + '\n').encode('utf-8')
|
| 65 |
-
else:
|
| 66 |
-
# Pass through messages like 'data: [DONE]'
|
| 67 |
-
yield (line + '\n').encode('utf-8')
|
| 68 |
-
except json.JSONDecodeError:
|
| 69 |
-
# If it's not valid JSON, pass it through as is
|
| 70 |
-
yield (line + '\n').encode('utf-8')
|
| 71 |
-
else:
|
| 72 |
-
# Pass through non-data lines (e.g., empty lines, comments)
|
| 73 |
-
yield (line + '\n').encode('utf-8')
|
| 74 |
-
|
| 75 |
-
# Yield any remaining data in the buffer
|
| 76 |
-
if buffer:
|
| 77 |
-
yield buffer.encode('utf-8')
|
| 78 |
-
|
| 79 |
-
|
| 80 |
# --- HTTPX Client Lifecycle Management ---
|
| 81 |
@asynccontextmanager
|
| 82 |
async def lifespan(app: FastAPI):
|
|
@@ -143,8 +97,7 @@ async def reverse_proxy_handler(request: Request):
|
|
| 143 |
log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
|
| 144 |
|
| 145 |
return StreamingResponse(
|
| 146 |
-
|
| 147 |
-
modified_aiter_raw(rp_resp.aiter_raw()),
|
| 148 |
status_code=rp_resp.status_code,
|
| 149 |
headers=rp_resp.headers,
|
| 150 |
background=BackgroundTask(rp_resp.aclose),
|
|
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
from contextlib import asynccontextmanager
|
|
|
|
| 10 |
|
| 11 |
# --- Production-Ready Configuration ---
|
| 12 |
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
|
|
| 16 |
)
|
| 17 |
|
| 18 |
TARGET_URL = os.getenv("TARGET_URL", "https://api.gmi-serving.com")
|
| 19 |
+
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "12"))
|
| 20 |
DEFAULT_RETRY_CODES = "429,500,502,503,504"
|
| 21 |
RETRY_CODES_STR = os.getenv("RETRY_CODES", DEFAULT_RETRY_CODES)
|
| 22 |
try:
|
|
|
|
| 26 |
logging.error(f"Invalid RETRY_CODES format: '{RETRY_CODES_STR}'. Falling back to default: {DEFAULT_RETRY_CODES}")
|
| 27 |
RETRY_STATUS_CODES = {int(code.strip()) for code in DEFAULT_RETRY_CODES.split(',')}
|
| 28 |
|
| 29 |
+
# --- Helper Function ---
|
|
|
|
| 30 |
def generate_random_ip():
|
| 31 |
"""Generates a random, valid-looking IPv4 address."""
|
| 32 |
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# --- HTTPX Client Lifecycle Management ---
|
| 35 |
@asynccontextmanager
|
| 36 |
async def lifespan(app: FastAPI):
|
|
|
|
| 97 |
log_func(f"Request finished: {request.method} {request.url.path} status_code={rp_resp.status_code} latency={duration_ms:.2f}ms")
|
| 98 |
|
| 99 |
return StreamingResponse(
|
| 100 |
+
rp_resp.aiter_raw(),
|
|
|
|
| 101 |
status_code=rp_resp.status_code,
|
| 102 |
headers=rp_resp.headers,
|
| 103 |
background=BackgroundTask(rp_resp.aclose),
|