Spaces: Running on Zero
jedick committed
Commit · f8c72d3
Parent(s): dac4e7d
Use session state for LangChain graph

app.py CHANGED
@@ -18,55 +18,58 @@ import os
 # Setup environment variables
 load_dotenv(dotenv_path=".env", override=True)

-# Global
-COMPUTE = "local"
+# Global setting for search type
 search_type = "hybrid"

-# Global variables for LangChain graph
+# Global variables for LangChain graph: use dictionaries to store user-specific instances
+# https://www.gradio.app/guides/state-in-blocks
+graph_instances = {"local": {}, "remote": {}}


-def run_workflow(input, history, thread_id):
+def cleanup_graph(request: gr.Request):
+    if request.session_hash in graph_instances["local"]:
+        del graph_instances["local"][request.session_hash]
+        print(f"Deleted local graph for session {request.session_hash}")
+    if request.session_hash in graph_instances["remote"]:
+        del graph_instances["remote"][request.session_hash]
+        print(f"Deleted remote graph for session {request.session_hash}")
+
+
+def run_workflow(input, history, compute_mode, thread_id, session_hash):
     """The main function to run the chat workflow"""

-    if COMPUTE == "local":
-        # We don't want the app to switch into remote mode without notification,
-        # so ask the user to do it
+    # Error if user tries to run local mode without GPU
+    if compute_mode == "local":
         if not torch.cuda.is_available():
             raise gr.Error(
                 "Local mode requires GPU. Please select remote mode.",
                 print_exception=False,
             )
+
+    # Get graph for compute mode
+    graph = graph_instances[compute_mode].get(session_hash)
+    if graph is not None:
+        print(f"Get {compute_mode} graph for session {session_hash}")

     if graph is None:
         # Notify when we're loading the local model because it takes some time
-        if COMPUTE == "local":
+        if compute_mode == "local":
             gr.Info(
                 f"Please wait for the local model to load",
                 duration=15,
                 title=f"Model loading...",
             )
         # Get the chat model and build the graph
-        chat_model = GetChatModel(COMPUTE)
-        graph_builder = BuildGraph(chat_model, COMPUTE, search_type)
+        chat_model = GetChatModel(compute_mode)
+        graph_builder = BuildGraph(chat_model, compute_mode, search_type)
         # Compile the graph with an in-memory checkpointer
         memory = MemorySaver()
         graph = graph_builder.compile(checkpointer=memory)
         # Set global graph for compute mode
-        # Notify when model finishes loading
-        gr.Success(f"{COMPUTE}", duration=4, title=f"Model loaded!")
-        print(f"Set graph for {COMPUTE}, {search_type}!")
+        graph_instances[compute_mode][session_hash] = graph
+        print(f"Set {compute_mode} graph for session {session_hash}")
+        # Notify when model finishes loading
+        gr.Success(f"{compute_mode}", duration=4, title=f"Model loaded")

     print(f"Using thread_id: {thread_id}")

@@ -180,13 +183,16 @@ def run_workflow(input, history, thread_id):
     yield history, None, citations


-def to_workflow(*args):
+def to_workflow(request: gr.Request, *args):
     """Wrapper function to call function with or without @spaces.GPU"""
-    if COMPUTE == "local":
-        for value in run_workflow_local(*args):
+    compute_mode = args[2]
+    # Add session_hash to arguments
+    new_args = args + (request.session_hash,)
+    if compute_mode == "local":
+        for value in run_workflow_local(*new_args):
             yield value
-    if COMPUTE == "remote":
-        for value in run_workflow_remote(*args):
+    if compute_mode == "remote":
+        for value in run_workflow_remote(*new_args):
             yield value

@@ -236,7 +242,7 @@ with gr.Blocks(
             "local",
             "remote",
         ],
-        value=COMPUTE,
+        value=("local" if torch.cuda.is_available() else "remote"),
         label="Compute Mode",
        info=(None if torch.cuda.is_available() else "NOTE: local mode requires GPU"),
        render=False,
@@ -444,10 +450,6 @@ with gr.Blocks(
         """Return updated value for a component"""
         return gr.update(value=value)

-    def set_compute(compute_mode):
-        global COMPUTE
-        COMPUTE = compute_mode
-
     def set_avatar(compute_mode):
         if compute_mode == "remote":
             image_file = "images/cloud.png"
@@ -475,13 +477,6 @@ with gr.Blocks(
         # Display the content in the textbox
         return content, change_visibility(True)

-    # def update_citations(citations):
-    #     if citations == []:
-    #         # Blank out and hide the citations textbox when new input is submitted
-    #         return "", change_visibility(False)
-    #     else:
-    #         return citations, change_visibility(True)
-
     # --------------
     # Event handlers
     # --------------
@@ -495,11 +490,6 @@ with gr.Blocks(
         return component.clear()

     compute_mode.change(
-        # Update global COMPUTE variable
-        set_compute,
-        [compute_mode],
-        api_name=False,
-    ).then(
         # Change the app status text
         get_status_text,
         [compute_mode],
@@ -527,7 +517,7 @@ with gr.Blocks(
     input.submit(
         # Submit input to the chatbot
         to_workflow,
-        [input, chatbot, thread_id],
+        [input, chatbot, compute_mode, thread_id],
         [chatbot, retrieved_emails, citations_text],
         api_name=False,
     )
@@ -661,6 +651,9 @@ with gr.Blocks(
     )
     # fmt: on

+    # Clean up graph instances when page is closed/refreshed
+    demo.unload(cleanup_graph)
+

 if __name__ == "__main__":
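
The commit's core pattern is to key per-user instances on `gr.Request.session_hash` and free them in a `demo.unload` callback. A minimal standalone sketch of that pattern follows; the names (`instances`, `greet`, `cleanup`) are illustrative, not the app's actual code:

```python
import gradio as gr

# Illustrative stand-in for graph_instances: one entry per browser session
instances = {}

def greet(name, request: gr.Request):
    # Gradio injects gr.Request when a parameter is annotated with it;
    # session_hash identifies the caller's browser session
    state = instances.setdefault(request.session_hash, {"count": 0})
    state["count"] += 1
    return f"Hello {name}! Call #{state['count']} in this session."

def cleanup(request: gr.Request):
    # Drop this session's entry when the tab is closed or refreshed
    instances.pop(request.session_hash, None)

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    output = gr.Textbox(label="Output")
    name.submit(greet, [name], [output])
    # The same hook the commit uses to free per-session graphs
    demo.unload(cleanup)

if __name__ == "__main__":
    demo.launch()
```

Keying on `session_hash` instead of a module-level global is what lets concurrent users hold separate graphs, which the removed `COMPUTE` global could not guarantee.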
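The diff also keeps compiling each graph with `MemorySaver()` and threading a `thread_id` through `run_workflow`. In LangGraph, the checkpointer keys saved conversation state by that `thread_id`, so per-session graphs plus per-chat thread IDs give isolated histories. Below is a minimal sketch of the mechanism, with a trivial echo node standing in for the app's `BuildGraph`:

```python
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

def respond(state: MessagesState):
    # Trivial node standing in for the app's chat-model call
    return {"messages": [AIMessage(f"Echo: {state['messages'][-1].content}")]}

builder = StateGraph(MessagesState)
builder.add_node("respond", respond)
builder.add_edge(START, "respond")

# Compile with an in-memory checkpointer, as in the diff
graph = builder.compile(checkpointer=MemorySaver())

# State is saved per thread_id: reusing it continues the conversation
config = {"configurable": {"thread_id": "abc123"}}
graph.invoke({"messages": [HumanMessage("Hello")]}, config)
result = graph.invoke({"messages": [HumanMessage("Again")]}, config)
print(len(result["messages"]))  # 4: two human turns, two echoes
```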
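Finally, `to_workflow` dispatches to `run_workflow_local` or `run_workflow_remote`, both defined outside this diff. On a ZeroGPU Space (this one is "Running on Zero"), a common arrangement is that only the local variant carries the `@spaces.GPU` decorator, so GPU hardware is requested only when local mode needs it. This is an assumption about code not shown in the commit:

```python
import spaces

# Assumed wrappers (defined elsewhere in app.py, outside this diff)
@spaces.GPU
def run_workflow_local(*args):
    # ZeroGPU attaches a GPU for the duration of this call
    yield from run_workflow(*args)

def run_workflow_remote(*args):
    # Remote mode calls an external API, so no GPU is requested
    yield from run_workflow(*args)
```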