Commit
·
be547ae
1
Parent(s):
da6c071
support multiple express keys
Browse files
- app/config.py +5 -1
- app/routes/chat_api.py +22 -8
app/config.py
CHANGED
|
@@ -13,7 +13,11 @@ CREDENTIALS_DIR = os.environ.get("CREDENTIALS_DIR", "/app/credentials")
|
|
| 13 |
GOOGLE_CREDENTIALS_JSON_STR = os.environ.get("GOOGLE_CREDENTIALS_JSON")
|
| 14 |
|
| 15 |
# API Key for Vertex Express Mode
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Fake streaming settings for debugging/testing
|
| 19 |
FAKE_STREAMING_ENABLED = os.environ.get("FAKE_STREAMING", "false").lower() == "true"
|
|
|
|
| 13 |
GOOGLE_CREDENTIALS_JSON_STR = os.environ.get("GOOGLE_CREDENTIALS_JSON")
|
| 14 |
|
| 15 |
# API Key for Vertex Express Mode
|
| 16 |
+
raw_vertex_keys = os.environ.get("VERTEX_EXPRESS_API_KEY")
|
| 17 |
+
if raw_vertex_keys:
|
| 18 |
+
VERTEX_EXPRESS_API_KEY_VAL = [key.strip() for key in raw_vertex_keys.split(',') if key.strip()]
|
| 19 |
+
else:
|
| 20 |
+
VERTEX_EXPRESS_API_KEY_VAL = []
|
| 21 |
|
| 22 |
# Fake streaming settings for debugging/testing
|
| 23 |
FAKE_STREAMING_ENABLED = os.environ.get("FAKE_STREAMING", "false").lower() == "true"
|
app/routes/chat_api.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import asyncio
|
| 2 |
import json # Needed for error streaming
|
|
|
|
| 3 |
from fastapi import APIRouter, Depends, Request
|
| 4 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 5 |
from typing import List, Dict, Any
|
|
@@ -100,16 +101,29 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
| 100 |
generation_config = create_generation_config(request)
|
| 101 |
|
| 102 |
client_to_use = None
|
| 103 |
-
|
| 104 |
|
| 105 |
# Use dynamically fetched express models list for this check
|
| 106 |
-
if
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
if client_to_use is None:
|
| 115 |
rotated_credentials, rotated_project_id = credential_manager_instance.get_random_credentials()
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import json # Needed for error streaming
|
| 3 |
+
import random
|
| 4 |
from fastapi import APIRouter, Depends, Request
|
| 5 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 6 |
from typing import List, Dict, Any
|
|
|
|
| 101 |
generation_config = create_generation_config(request)
|
| 102 |
|
| 103 |
client_to_use = None
|
| 104 |
+
express_api_keys_list = app_config.VERTEX_EXPRESS_API_KEY_VAL
|
| 105 |
|
| 106 |
# Use dynamically fetched express models list for this check
|
| 107 |
+
if express_api_keys_list and base_model_name in vertex_express_model_ids: # Check against base_model_name
|
| 108 |
+
indexed_keys = list(enumerate(express_api_keys_list))
|
| 109 |
+
random.shuffle(indexed_keys)
|
| 110 |
+
|
| 111 |
+
for original_idx, key_val in indexed_keys:
|
| 112 |
+
try:
|
| 113 |
+
client_to_use = genai.Client(vertexai=True, api_key=key_val)
|
| 114 |
+
print(f"INFO: Using Vertex Express Mode for model {base_model_name} with API key (original index: {original_idx}).")
|
| 115 |
+
break # Successfully initialized client
|
| 116 |
+
except Exception as e:
|
| 117 |
+
print(f"WARNING: Vertex Express Mode client init failed for API key (original index: {original_idx}): {e}. Trying next key if available.")
|
| 118 |
+
client_to_use = None # Ensure client_to_use is None if this attempt fails
|
| 119 |
+
|
| 120 |
+
if client_to_use is None:
|
| 121 |
+
print(f"WARNING: All {len(express_api_keys_list)} Vertex Express API key(s) failed to initialize for model {base_model_name}. Falling back.")
|
| 122 |
+
# else:
|
| 123 |
+
# if not express_api_keys_list:
|
| 124 |
+
# print(f"DEBUG: No Vertex Express API keys configured. Skipping Express Mode attempt for model {base_model_name}.")
|
| 125 |
+
# elif base_model_name not in vertex_express_model_ids:
|
| 126 |
+
# print(f"DEBUG: Model {base_model_name} is not in the Vertex Express model list. Skipping Express Mode attempt.")
|
| 127 |
|
| 128 |
if client_to_use is None:
|
| 129 |
rotated_credentials, rotated_project_id = credential_manager_instance.get_random_credentials()
|