Spaces:
Runtime error
Runtime error
Commit
Β·
64d7e88
1
Parent(s):
daa4f11
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,6 +10,8 @@ from time import sleep
|
|
| 10 |
|
| 11 |
import inspect
|
| 12 |
|
|
|
|
|
|
|
| 13 |
from random import randint
|
| 14 |
|
| 15 |
from urllib.parse import quote
|
|
@@ -179,6 +181,7 @@ class RavenDemo(gr.Blocks):
|
|
| 179 |
self.summary_model_client = InferenceClient(config.summary_model_endpoint)
|
| 180 |
|
| 181 |
self.max_num_steps = 20
|
|
|
|
| 182 |
|
| 183 |
with self:
|
| 184 |
gr.HTML(HEADER_HTML)
|
|
@@ -299,6 +302,10 @@ class RavenDemo(gr.Blocks):
|
|
| 299 |
*steps,
|
| 300 |
)
|
| 301 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
user_input = gr.Textbox(interactive=False)
|
| 303 |
raven_function_call = ""
|
| 304 |
summary_model_summary = ""
|
|
@@ -307,7 +314,8 @@ class RavenDemo(gr.Blocks):
|
|
| 307 |
gmaps_html = ""
|
| 308 |
steps_accordion = gr.Accordion(open=True)
|
| 309 |
steps = [gr.Textbox(value="", visible=False) for _ in range(self.max_num_steps)]
|
| 310 |
-
|
|
|
|
| 311 |
|
| 312 |
raven_prompt = self.functions_helper.get_prompt(
|
| 313 |
query.replace("'", r"\'").replace('"', r"\"")
|
|
@@ -328,7 +336,18 @@ class RavenDemo(gr.Blocks):
|
|
| 328 |
r_calls = [c.strip() for c in raven_function_call.split(";") if c.strip()]
|
| 329 |
f_r_calls = []
|
| 330 |
for r_c in r_calls:
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
f_r_calls.append(f_r_call)
|
| 333 |
|
| 334 |
raven_function_call = "; ".join(f_r_calls)
|
|
@@ -424,6 +443,21 @@ class RavenDemo(gr.Blocks):
|
|
| 424 |
user_input = gr.Textbox(interactive=True, autofocus=False)
|
| 425 |
yield get_returns()
|
| 426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
def get_summary_model_prompt(self, results: List, query: str) -> None:
|
| 428 |
# TODO check what outputs are returned and return them properly
|
| 429 |
ALLOWED_KEYS = [
|
|
|
|
| 10 |
|
| 11 |
import inspect
|
| 12 |
|
| 13 |
+
import ast
|
| 14 |
+
|
| 15 |
from random import randint
|
| 16 |
|
| 17 |
from urllib.parse import quote
|
|
|
|
| 181 |
self.summary_model_client = InferenceClient(config.summary_model_endpoint)
|
| 182 |
|
| 183 |
self.max_num_steps = 20
|
| 184 |
+
self.function_call_name_set = set([f.name for f in FUNCTIONS])
|
| 185 |
|
| 186 |
with self:
|
| 187 |
gr.HTML(HEADER_HTML)
|
|
|
|
| 302 |
*steps,
|
| 303 |
)
|
| 304 |
|
| 305 |
+
def on_error():
|
| 306 |
+
initial_return[0] = gr.Textbox(interactive=True, autofocus=False)
|
| 307 |
+
return initial_return
|
| 308 |
+
|
| 309 |
user_input = gr.Textbox(interactive=False)
|
| 310 |
raven_function_call = ""
|
| 311 |
summary_model_summary = ""
|
|
|
|
| 314 |
gmaps_html = ""
|
| 315 |
steps_accordion = gr.Accordion(open=True)
|
| 316 |
steps = [gr.Textbox(value="", visible=False) for _ in range(self.max_num_steps)]
|
| 317 |
+
initial_return = list(get_returns())
|
| 318 |
+
yield initial_return
|
| 319 |
|
| 320 |
raven_prompt = self.functions_helper.get_prompt(
|
| 321 |
query.replace("'", r"\'").replace('"', r"\"")
|
|
|
|
| 336 |
r_calls = [c.strip() for c in raven_function_call.split(";") if c.strip()]
|
| 337 |
f_r_calls = []
|
| 338 |
for r_c in r_calls:
|
| 339 |
+
try:
|
| 340 |
+
f_r_call = format_str(r_c.strip(), mode=Mode())
|
| 341 |
+
except:
|
| 342 |
+
yield on_error()
|
| 343 |
+
gr.Warning(ERROR_MESSAGE)
|
| 344 |
+
return
|
| 345 |
+
|
| 346 |
+
if not self.whitelist_function_names(f_r_call):
|
| 347 |
+
yield on_error()
|
| 348 |
+
gr.Warning(ERROR_MESSAGE)
|
| 349 |
+
return
|
| 350 |
+
|
| 351 |
f_r_calls.append(f_r_call)
|
| 352 |
|
| 353 |
raven_function_call = "; ".join(f_r_calls)
|
|
|
|
| 443 |
user_input = gr.Textbox(interactive=True, autofocus=False)
|
| 444 |
yield get_returns()
|
| 445 |
|
| 446 |
+
def whitelist_function_names(self, function_call_str: str) -> bool:
|
| 447 |
+
"""
|
| 448 |
+
Defensive function name whitelisting inspired by @evan-nexusflow
|
| 449 |
+
"""
|
| 450 |
+
for expr in ast.walk(ast.parse(function_call_str)):
|
| 451 |
+
if not isinstance(expr, ast.Call):
|
| 452 |
+
continue
|
| 453 |
+
|
| 454 |
+
expr: ast.Call
|
| 455 |
+
function_name = expr.func.id
|
| 456 |
+
if function_name not in self.function_call_name_set:
|
| 457 |
+
return False
|
| 458 |
+
|
| 459 |
+
return True
|
| 460 |
+
|
| 461 |
def get_summary_model_prompt(self, results: List, query: str) -> None:
|
| 462 |
# TODO check what outputs are returned and return them properly
|
| 463 |
ALLOWED_KEYS = [
|