rkihacker commited on
Commit
81706a7
·
verified ·
1 Parent(s): cb1bae0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -57
app.py CHANGED
@@ -5,95 +5,99 @@ import json
5
 
6
  app = Flask(__name__)
7
 
8
- # Target API base URL from environment variable
9
- TARGET_API = os.getenv("TARGET_API", "https://huggingface.co")
10
 
11
- # Path mappings from environment variable
12
- # Expected format: {"path1": "mapped_path1", "path2": "mapped_path2"}
 
 
 
 
 
 
 
 
 
 
13
  def get_path_mappings():
14
- mappings_str = os.getenv("PATH_MAPPINGS", '{"/": "/"}')
15
  try:
16
  return json.loads(mappings_str)
17
  except json.JSONDecodeError:
18
- # Fallback to default mappings if JSON is invalid
19
- return {
20
- "/": "/",
21
- }
22
 
23
  PATH_MAPPINGS = get_path_mappings()
24
 
25
 
26
  @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
27
  def proxy(path):
28
- # Construct the full path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  full_path = f"/{path}"
30
 
31
- # Apply path mapping if matches
 
32
  for original_path, new_path in PATH_MAPPINGS.items():
33
- if full_path.startswith(original_path):
34
- full_path = full_path.replace(original_path, new_path, 1)
35
  break
36
 
37
- # Construct target URL
38
  target_url = f"{TARGET_API}{full_path}"
39
 
40
- # Forward the request to the target API
41
- headers = {key: value for key, value in request.headers if key != 'Host'}
 
 
 
 
42
 
43
- # Handle streaming response
44
- if request.method == 'POST':
45
- response = requests.post(
46
- target_url,
47
- headers=headers,
48
- json=request.get_json(silent=True),
49
- params=request.args,
50
- stream=True
51
- )
52
- elif request.method == 'GET':
53
- response = requests.get(
54
- target_url,
55
- headers=headers,
56
- params=request.args,
57
- stream=True
58
- )
59
- elif request.method == 'PUT':
60
- response = requests.put(
61
- target_url,
62
- headers=headers,
63
- json=request.get_json(silent=True),
64
- params=request.args,
65
- stream=True
66
- )
67
- elif request.method == 'DELETE':
68
- response = requests.delete(
69
- target_url,
70
- headers=headers,
71
- params=request.args,
72
- stream=True
73
- )
74
- elif request.method == 'PATCH':
75
- response = requests.patch(
76
- target_url,
77
  headers=headers,
78
- json=request.get_json(silent=True),
79
  params=request.args,
 
80
  stream=True
81
  )
 
 
 
 
82
 
83
- # Create a response with the same status code, headers, and streaming content
84
  def generate():
85
  for chunk in response.iter_content(chunk_size=8192):
86
  yield chunk
87
 
88
- # Create flask response
89
  proxy_response = Response(
90
  stream_with_context(generate()),
91
  status=response.status_code
92
  )
93
 
94
- # Forward response headers
 
 
95
  for key, value in response.headers.items():
96
- if key.lower() not in ('content-length', 'transfer-encoding', 'connection'):
97
  proxy_response.headers[key] = value
98
 
99
  return proxy_response
@@ -101,8 +105,9 @@ def proxy(path):
101
 
102
  @app.route('/', methods=['GET'])
103
  def index():
104
- return "service running."
105
 
106
 
107
  if __name__ == '__main__':
108
- app.run(host='0.0.0.0', port=7860, debug=False)
 
 
5
 
6
  app = Flask(__name__)
7
 
8
+ # --- Configuration from Environment Variables ---
 
9
 
10
+ # 1. Target API base URL
11
+ TARGET_API = os.getenv("TARGET_API", "https://api-inference.huggingface.co")
12
+
13
+ # 2. The REAL secret key for the target API. This is kept on the server.
14
+ REAL_AUTH_KEY = os.getenv("REAL_AUTH_KEY")
15
+
16
+ # 3. The access key for this proxy. This is what you share with your friends.
17
+ PROXY_ACCESS_KEY = os.getenv("PROXY_ACCESS_KEY")
18
+
19
+
20
+ # 4. Path mappings from environment variable
21
+ # Example: '{"/v1/chat/completions": "/models/mistralai/Mixtral-8x7B-Instruct-v0.1"}'
22
  def get_path_mappings():
23
+ mappings_str = os.getenv("PATH_MAPPINGS", '{}') # Default to empty dict
24
  try:
25
  return json.loads(mappings_str)
26
  except json.JSONDecodeError:
27
+ print("Warning: Invalid JSON in PATH_MAPPINGS. Using empty mappings.")
28
+ return {}
 
 
29
 
30
  PATH_MAPPINGS = get_path_mappings()
31
 
32
 
33
  @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
34
  def proxy(path):
35
+ # --- 1. Authentication Check ---
36
+ # Ensure the server is configured with the necessary keys
37
+ if not REAL_AUTH_KEY or not PROXY_ACCESS_KEY:
38
+ error_msg = {"error": "Authentication is not configured on the proxy server."}
39
+ return Response(json.dumps(error_msg), status=500, mimetype='application/json')
40
+
41
+ # Get the authorization header from the user's request
42
+ auth_header = request.headers.get('Authorization')
43
+ expected_auth_header = f"Bearer {PROXY_ACCESS_KEY}"
44
+
45
+ # Validate the proxy access key
46
+ if auth_header != expected_auth_header:
47
+ error_msg = {"error": "Invalid or missing proxy access key."}
48
+ return Response(json.dumps(error_msg), status=401, mimetype='application/json')
49
+
50
+ # --- 2. Path and URL Construction ---
51
  full_path = f"/{path}"
52
 
53
+ # Apply path mapping if a match is found
54
+ # This allows you to map a generic path like /v1/chat to a specific model path
55
  for original_path, new_path in PATH_MAPPINGS.items():
56
+ if full_path == original_path:
57
+ full_path = new_path
58
  break
59
 
 
60
  target_url = f"{TARGET_API}{full_path}"
61
 
62
+ # --- 3. Header Manipulation ---
63
+ # Copy headers from the incoming request, but remove 'Host' and the user's 'Authorization'
64
+ headers = {key: value for key, value in request.headers if key.lower() not in ['host', 'authorization']}
65
+
66
+ # Add the REAL authentication key for the target API
67
+ headers['Authorization'] = f"Bearer {REAL_AUTH_KEY}"
68
 
69
+ # --- 4. Forward the Request (Refactored for all methods) ---
70
+ try:
71
+ # Use request.get_data() to handle any kind of request body (JSON, form data, etc.)
72
+ response = requests.request(
73
+ method=request.method,
74
+ url=target_url,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  headers=headers,
 
76
  params=request.args,
77
+ data=request.get_data(),
78
  stream=True
79
  )
80
+ except requests.exceptions.RequestException as e:
81
+ error_msg = {"error": f"Failed to connect to target service: {e}"}
82
+ return Response(json.dumps(error_msg), status=502, mimetype='application/json')
83
+
84
 
85
+ # --- 5. Stream the Response Back to the Client ---
86
  def generate():
87
  for chunk in response.iter_content(chunk_size=8192):
88
  yield chunk
89
 
90
+ # Create a Flask response object
91
  proxy_response = Response(
92
  stream_with_context(generate()),
93
  status=response.status_code
94
  )
95
 
96
+ # Copy headers from the target's response to our proxy response
97
+ # Exclude certain headers that are handled by the WSGI server
98
+ excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
99
  for key, value in response.headers.items():
100
+ if key.lower() not in excluded_headers:
101
  proxy_response.headers[key] = value
102
 
103
  return proxy_response
 
105
 
106
  @app.route('/', methods=['GET'])
107
  def index():
108
+ return "Proxy service is running."
109
 
110
 
111
  if __name__ == '__main__':
112
+ # It's recommended to run this with a production-grade WSGI server like Gunicorn or uWSGI
113
+ app.run(host='0.0.0.0', port=7860, debug=False)