File size: 4,155 Bytes
e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 75257d5 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 e1a4f8a 81706a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import hmac
import json
import os

import requests
from flask import Flask, request, Response, stream_with_context
# Flask application object; the proxy route and the root health-check route
# below are registered on it via @app.route.
app = Flask(__name__)
# --- Settings read once from the process environment at import time ---

# Base URL of the upstream API; each proxied request path is appended to it.
TARGET_API = os.environ.get("TARGET_API", "https://api-inference.huggingface.co")

# Credential for the upstream API. It stays on this server: clients never send
# it, the proxy injects it into forwarded requests. None when unset.
REAL_AUTH_KEY = os.environ.get("REAL_AUTH_KEY")

# Shared secret that clients of this proxy present as "Bearer <key>" in their
# Authorization header. None when unset.
PROXY_ACCESS_KEY = os.environ.get("PROXY_ACCESS_KEY")
# 4. Optional path rewrites, given as a JSON object in PATH_MAPPINGS that maps
#    an incoming request path to the upstream path it is forwarded to.
# Example: '{"/v1/chat/completions": "/models/mistralai/Mixtral-8x7B-Instruct-v0.1"}'
def get_path_mappings():
    """Load the PATH_MAPPINGS environment variable as a dict.

    Returns:
        The decoded JSON object (incoming path -> upstream path). An empty
        dict is returned when the variable is unset or blank, and also, with
        a printed warning, when the value is not valid JSON or is valid JSON
        that is not an object (e.g. a list). The result is used as a mapping
        by the request handler, so anything other than a dict would make
        every proxied request fail.
    """
    mappings_str = os.getenv("PATH_MAPPINGS", '{}')  # Default to empty dict
    if not mappings_str.strip():
        # A blank value simply means "no mappings"; it is not an error.
        return {}
    try:
        mappings = json.loads(mappings_str)
    except json.JSONDecodeError:
        print("Warning: Invalid JSON in PATH_MAPPINGS. Using empty mappings.")
        return {}
    if not isinstance(mappings, dict):
        print("Warning: PATH_MAPPINGS must be a JSON object. Using empty mappings.")
        return {}
    return mappings


PATH_MAPPINGS = get_path_mappings()
# Request headers that are NOT copied to the upstream call:
#  - host / authorization: replaced by the target host and the real API key;
#  - content-length: recomputed by requests from the forwarded body;
#  - accept-encoding: left to requests' default, so the upstream only picks an
#    encoding requests can decode (the body is relayed decoded and the
#    Content-Encoding response header is dropped);
#  - the remainder are hop-by-hop headers describing the client connection.
_REQUEST_HEADERS_NOT_FORWARDED = frozenset({
    'host', 'authorization', 'content-length', 'accept-encoding',
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'transfer-encoding', 'upgrade',
})

# Upstream response headers that are NOT copied back: the body is re-framed by
# the WSGI server and already content-decoded by requests, and WSGI
# applications must not emit hop-by-hop headers (PEP 3333).
_RESPONSE_HEADERS_NOT_FORWARDED = frozenset({
    'content-encoding', 'content-length', 'transfer-encoding', 'connection',
    'keep-alive', 'proxy-authenticate', 'trailer', 'upgrade',
})

# Seconds to wait for the upstream connection to be established. There is
# deliberately no read timeout: the upstream (an inference API by default)
# may take a long time before it starts sending output.
_UPSTREAM_CONNECT_TIMEOUT = 10

# Maximum bytes read per iteration when relaying the upstream body. For
# chunked upstream responses (typical for streamed output) each received chunk
# is yielded without waiting to fill this size.
_STREAM_CHUNK_SIZE = 8192


def _json_error(message, status):
    """Build a JSON error response with body {"error": message}."""
    return Response(json.dumps({"error": message}), status=status, mimetype='application/json')


@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
def proxy(path):
    """Authenticate the caller and forward the request to TARGET_API.

    The caller must send ``Authorization: Bearer <PROXY_ACCESS_KEY>``; that
    header is replaced with ``Bearer <REAL_AUTH_KEY>`` before forwarding. The
    path is rewritten through PATH_MAPPINGS when it matches a key exactly;
    the method, query string, body and remaining end-to-end headers are
    forwarded unchanged.

    Args:
        path: Request path after the leading slash (Flask ``path`` converter).

    Returns:
        A streaming Response relaying the upstream status, headers and body,
        or a JSON error: 500 if the proxy keys are not configured, 401 if the
        access key is missing or wrong, 502 if the upstream request fails.
    """
    # --- 1. Authentication check ---
    # Refuse to run at all unless both secrets are configured on the server.
    if not REAL_AUTH_KEY or not PROXY_ACCESS_KEY:
        return _json_error("Authentication is not configured on the proxy server.", 500)

    auth_header = request.headers.get('Authorization', '')
    expected_auth_header = f"Bearer {PROXY_ACCESS_KEY}"
    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed key is correct.
    if not hmac.compare_digest(auth_header.encode('utf-8'),
                               expected_auth_header.encode('utf-8')):
        return _json_error("Invalid or missing proxy access key.", 401)

    # --- 2. Path and URL construction ---
    full_path = f"/{path}"
    # Map a generic client-facing path (e.g. /v1/chat/completions) to a
    # specific upstream path; unmapped paths are forwarded unchanged.
    full_path = PATH_MAPPINGS.get(full_path, full_path)
    # Strip a trailing slash so a base configured as "https://host/" does not
    # produce "https://host//...".
    target_url = f"{TARGET_API.rstrip('/')}{full_path}"

    # --- 3. Header manipulation ---
    headers = {
        key: value
        for key, value in request.headers
        if key.lower() not in _REQUEST_HEADERS_NOT_FORWARDED
    }
    # Swap in the REAL key for the target API.
    headers['Authorization'] = f"Bearer {REAL_AUTH_KEY}"

    # --- 4. Forward the request (any method, any body type) ---
    try:
        upstream = requests.request(
            method=request.method,
            url=target_url,
            headers=headers,
            # items(multi=True) keeps repeated query parameters (?a=1&a=2);
            # passing the MultiDict itself would keep only the first value.
            params=list(request.args.items(multi=True)),
            data=request.get_data(),
            stream=True,
            timeout=(_UPSTREAM_CONNECT_TIMEOUT, None),
        )
    except requests.exceptions.RequestException as e:
        return _json_error(f"Failed to connect to target service: {e}", 502)

    # --- 5. Stream the response back to the client ---
    proxy_response = Response(
        stream_with_context(upstream.iter_content(chunk_size=_STREAM_CHUNK_SIZE)),
        status=upstream.status_code,
    )
    # Release the upstream connection when the WSGI server closes this
    # response, which PEP 3333 requires on normal completion and on early
    # client disconnect alike.
    proxy_response.call_on_close(upstream.close)

    for key, value in upstream.headers.items():
        if key.lower() not in _RESPONSE_HEADERS_NOT_FORWARDED:
            proxy_response.headers[key] = value
    return proxy_response
@app.route('/', methods=['GET'])
def index():
    """Health check for the proxy root.

    Returns a fixed status string so callers can confirm the service is up.
    No proxy access key is required for this route.
    """
    status_message = "Proxy service is running."
    return status_message
if __name__ == '__main__':
    # Flask's built-in server, for local/development use. In production,
    # serve `app` with a WSGI server such as Gunicorn or uWSGI instead.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(host=listen_host, port=listen_port, debug=False)