Spaces:
Sleeping
Sleeping
Commit
·
1e622b4
1
Parent(s):
135f3ac
Add retry logic upon schema fail for Mistral API calls
Browse files
utils.py
CHANGED
|
@@ -190,7 +190,7 @@ def llm_stream_serverless(prompt,model):
|
|
| 190 |
LAST_REQUEST_TIME = None
|
| 191 |
REQUEST_INTERVAL = 0.5 # Minimum time interval between requests in seconds
|
| 192 |
|
| 193 |
-
def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict[str, Any]]:
|
| 194 |
global LAST_REQUEST_TIME
|
| 195 |
current_time = time()
|
| 196 |
if LAST_REQUEST_TIME is not None:
|
|
@@ -227,10 +227,24 @@ def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict
|
|
| 227 |
print(result)
|
| 228 |
output = result['choices'][0]['message']['content']
|
| 229 |
if pydantic_model_class:
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
#
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
else:
|
| 235 |
print("No pydantic model class provided, returning without class validation")
|
| 236 |
return json.loads(output)
|
|
|
|
| 190 |
LAST_REQUEST_TIME = None
|
| 191 |
REQUEST_INTERVAL = 0.5 # Minimum time interval between requests in seconds
|
| 192 |
|
| 193 |
+
def llm_stream_mistral_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]:
|
| 194 |
global LAST_REQUEST_TIME
|
| 195 |
current_time = time()
|
| 196 |
if LAST_REQUEST_TIME is not None:
|
|
|
|
| 227 |
print(result)
|
| 228 |
output = result['choices'][0]['message']['content']
|
| 229 |
if pydantic_model_class:
|
| 230 |
+
# TODO: Use more robust error handling that works for all cases without retrying?
|
| 231 |
+
# Maybe APIs that don't have grammar should be avoided?
|
| 232 |
+
# Investigate grammar enforcement with open-ended generations?
|
| 233 |
+
try:
|
| 234 |
+
parsed_result = pydantic_model_class.model_validate_json(output)
|
| 235 |
+
print(parsed_result)
|
| 236 |
+
# model_validate_json above raises an exception if the output is invalid.
|
| 237 |
+
except Exception as e:
|
| 238 |
+
print(f"Error validating pydantic model: {e}")
|
| 239 |
+
# Let's retry by calling ourselves again if attempts < 3
|
| 240 |
+
if attempts == 0:
|
| 241 |
+
# We modify the prompt to remind it to output JSON in the required format
|
| 242 |
+
prompt = f"{prompt} You must output the JSON in the required format!"
|
| 243 |
+
if attempts < 3:
|
| 244 |
+
attempts += 1
|
| 245 |
+
print(f"Retrying Mistral API call, attempt {attempts}")
|
| 246 |
+
return llm_stream_mistral_api(prompt, pydantic_model_class, attempts)
|
| 247 |
+
|
| 248 |
else:
|
| 249 |
print("No pydantic model class provided, returning without class validation")
|
| 250 |
return json.loads(output)
|