# aks / app.py
# Author: rkihacker
# Last commit: 75257d5 "Update app.py" (verified)
import hmac
import json
import os

import requests
from flask import Flask, request, Response, stream_with_context
app = Flask(__name__)

# ---------------------------------------------------------------------------
# Settings, read once from the process environment when the module loads.
# ---------------------------------------------------------------------------

# Upstream base URL; the (possibly remapped) request path is appended to it.
TARGET_API = os.environ.get("TARGET_API", "https://api-inference.huggingface.co")

# Credential sent to the upstream API. It stays on this server and is never
# returned to clients.
REAL_AUTH_KEY = os.environ.get("REAL_AUTH_KEY")

# Credential that clients must present as "Bearer <key>" to use this proxy.
PROXY_ACCESS_KEY = os.environ.get("PROXY_ACCESS_KEY")
# 4. Path mappings, read from the PATH_MAPPINGS environment variable as a JSON object.
# Example: '{"/v1/chat/completions": "/models/mistralai/Mixtral-8x7B-Instruct-v0.1"}'
def get_path_mappings():
    """Load the exact-match path rewrite table from ``PATH_MAPPINGS``.

    The variable must hold a JSON object whose keys are incoming request
    paths (with leading slash) and whose values are the upstream paths to
    use instead.

    Returns:
        dict: the parsed mapping. An empty dict is returned, with a printed
        warning, when the variable is not valid JSON or is not a JSON object
        of string values. An empty dict is returned silently when the
        variable is unset.
    """
    mappings_str = os.getenv("PATH_MAPPINGS", '{}')  # Default to empty dict
    try:
        mappings = json.loads(mappings_str)
    except json.JSONDecodeError:
        print("Warning: Invalid JSON in PATH_MAPPINGS. Using empty mappings.")
        return {}
    # Valid JSON of the wrong shape (e.g. a list, or numeric targets) would
    # otherwise break every request later. Reject it here, once, at startup.
    if not isinstance(mappings, dict) or not all(isinstance(v, str) for v in mappings.values()):
        print("Warning: PATH_MAPPINGS must be a JSON object of string paths. Using empty mappings.")
        return {}
    return mappings


PATH_MAPPINGS = get_path_mappings()
# Hop-by-hop headers (RFC 7230 section 6.1) describe a single connection and
# must not be forwarded by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'transfer-encoding', 'upgrade',
})
# Request headers dropped before forwarding. The upstream gets its own Host,
# and the client's proxy key is replaced by the real upstream key.
_EXCLUDED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {'host', 'authorization'}
# Response headers not copied back. iter_content() yields decoded bytes and
# the WSGI server re-frames the body, so the upstream encoding/length values
# would not match what the client receives.
_EXCLUDED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {'content-encoding', 'content-length'}
# (connect, read) timeout for upstream calls. A connect timeout fails fast when
# the upstream is unreachable. Read timeout None keeps the original
# "wait forever" behaviour so slow, long-running streams are not cut off.
_UPSTREAM_TIMEOUT = (10, None)
# Bytes read per iteration when relaying the body (the original read one byte
# at a time).
_STREAM_CHUNK_SIZE = 8192


def _json_error(message, status):
    """Build a JSON error response with body ``{"error": message}``."""
    return Response(json.dumps({"error": message}), status=status, mimetype='application/json')


def _is_authorized(auth_header):
    """Return True if *auth_header* equals ``Bearer <PROXY_ACCESS_KEY>``.

    The comparison is constant-time, so response timing does not reveal how
    much of a guessed key matched. Both sides are encoded to bytes because
    ``hmac.compare_digest`` rejects non-ASCII ``str`` values.
    """
    expected = f"Bearer {PROXY_ACCESS_KEY}".encode('utf-8')
    return hmac.compare_digest((auth_header or '').encode('utf-8'), expected)


@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
def proxy(path):
    """Authenticate the caller, forward the request to TARGET_API, and stream back the reply.

    Args:
        path: request path after the leading slash (Flask ``<path:path>``).

    Returns:
        A Flask ``Response`` in one of these forms:
        - 500 if the proxy keys are not configured.
        - 401 if the caller's bearer token is wrong or missing.
        - 502 if the upstream request fails.
        - Otherwise, the upstream status, its headers (minus the excluded
          ones), and a streamed body.
    """
    # --- 1. Authentication Check ---
    # Ensure the server is configured with the necessary keys.
    if not REAL_AUTH_KEY or not PROXY_ACCESS_KEY:
        return _json_error("Authentication is not configured on the proxy server.", 500)
    if not _is_authorized(request.headers.get('Authorization')):
        return _json_error("Invalid or missing proxy access key.", 401)

    # --- 2. Path and URL Construction ---
    # Exact-match rewrite, e.g. a generic /v1/chat path to a specific model path.
    full_path = f"/{path}"
    full_path = PATH_MAPPINGS.get(full_path, full_path)
    # rstrip avoids a double slash when TARGET_API is configured with a trailing "/".
    target_url = f"{TARGET_API.rstrip('/')}{full_path}"

    # --- 3. Header Manipulation ---
    headers = {
        key: value
        for key, value in request.headers
        if key.lower() not in _EXCLUDED_REQUEST_HEADERS
    }
    headers['Authorization'] = f"Bearer {REAL_AUTH_KEY}"

    # --- 4. Forward the Request ---
    try:
        upstream = requests.request(
            method=request.method,
            url=target_url,
            headers=headers,
            # items(multi=True) keeps repeated query parameters (?a=1&a=2).
            # Passing the MultiDict directly would keep only the first value.
            params=list(request.args.items(multi=True)),
            # get_data() handles any kind of body (JSON, form data, raw bytes).
            data=request.get_data(),
            stream=True,
            timeout=_UPSTREAM_TIMEOUT,
        )
    except requests.exceptions.RequestException as e:
        return _json_error(f"Failed to connect to target service: {e}", 502)

    # --- 5. Stream the Response Back to the Client ---
    proxy_response = Response(
        stream_with_context(upstream.iter_content(chunk_size=_STREAM_CHUNK_SIZE)),
        status=upstream.status_code,
    )
    # Release the upstream connection when the WSGI server closes our response.
    # This covers normal completion, a client disconnect, and responses whose
    # body is never iterated (e.g. HEAD).
    proxy_response.call_on_close(upstream.close)
    for key, value in upstream.headers.items():
        if key.lower() not in _EXCLUDED_RESPONSE_HEADERS:
            proxy_response.headers[key] = value
    return proxy_response
@app.route('/', methods=['GET'])
def index():
    """Liveness endpoint: confirm that the proxy process is serving requests."""
    status_text = "Proxy service is running."
    return status_text
if __name__ == '__main__':
    # Flask's built-in server is for local use. In production, serve `app`
    # from a WSGI server such as Gunicorn or uWSGI instead.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(host=listen_host, port=listen_port, debug=False)