import gradio as gr
from main import GetChatModel, openai_model, model_id
from graph import BuildGraph
from retriever import db_dir
from langgraph.checkpoint.memory import MemorySaver
from dotenv import load_dotenv
from util import get_sources, get_start_end_months
import zipfile
import shutil
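# spaces provides the @spaces.GPU decorator for ZeroGPU allocation on Hugging Face Spaces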
import spaces
import torch
import boto3
import uuid
import ast
import os

# Set up environment variables
load_dotenv(dotenv_path=".env", override=True)

# Global setting for search type
search_type = "hybrid"

# Global variables for the LangChain graph: use dictionaries to store user-specific instances
# https://www.gradio.app/guides/state-in-blocks
graph_instances = {"local": {}, "remote": {}}


def cleanup_graph(request: gr.Request):
    if request.session_hash in graph_instances["local"]:
        del graph_instances["local"][request.session_hash]
        print(f"Deleted local graph for session {request.session_hash}")
    if request.session_hash in graph_instances["remote"]:
        del graph_instances["remote"][request.session_hash]
        print(f"Deleted remote graph for session {request.session_hash}")


def run_workflow(input, history, compute_mode, thread_id, session_hash):
    """The main function to run the chat workflow"""

    # Raise an error if the user tries to run local mode without a GPU
    if compute_mode == "local":
        if not torch.cuda.is_available():
            raise gr.Error(
                "Local mode requires a GPU. Please select remote mode.",
                print_exception=False,
            )

    # Get graph for compute mode
    graph = graph_instances[compute_mode].get(session_hash)
    if graph is not None:
        print(f"Get {compute_mode} graph for session {session_hash}")
    if graph is None:
        # Notify when we're loading the local model because it takes some time
        if compute_mode == "local":
            gr.Info(
                "Please wait for the local model to load",
                duration=15,
                title="Model loading...",
            )
        # Get the chat model and build the graph
        chat_model = GetChatModel(compute_mode)
        graph_builder = BuildGraph(chat_model, compute_mode, search_type)
        # Compile the graph with an in-memory checkpointer
        memory = MemorySaver()
        graph = graph_builder.compile(checkpointer=memory)
        # Set global graph for compute mode
        graph_instances[compute_mode][session_hash] = graph
        print(f"Set {compute_mode} graph for session {session_hash}")
        # Notify when model finishes loading
        gr.Success(f"{compute_mode}", duration=4, title="Model loaded")

    print(f"Using thread_id: {thread_id}")

    # Display the user input in the chatbot
    history.append(gr.ChatMessage(role="user", content=input))
    # Return the message history and empty lists for the emails and citations textboxes
    yield history, [], []

    # Stream graph steps for a single input
    # https://langchain-ai.lang.chat/langgraph/reference/graphs/#langgraph.graph.state.CompiledStateGraph
    for step in graph.stream(
        # Append the user input to the graph state
        {"messages": [{"role": "user", "content": input}]},
        config={"configurable": {"thread_id": thread_id}},
    ):
        # Get the node name and output chunk
        node, chunk = next(iter(step.items()))
        if node == "query":
            # Get the message (AIMessage class in LangChain)
            chunk_messages = chunk["messages"]
            # Look for tool calls
            if chunk_messages.tool_calls:
                # Loop over tool calls
                for tool_call in chunk_messages.tool_calls:
                    # Show the tool call with the arguments used
                    args = tool_call["args"]
                    content = args.get("search_query", "")
                    start_year = args.get("start_year")
                    end_year = args.get("end_year")
                    if start_year or end_year:
                        content = f"{content} ({start_year or ''} - {end_year or ''})"
                    if "months" in args:
                        content = f"{content} {args['months']}"
                    history.append(
                        gr.ChatMessage(
                            role="assistant",
                            content=content,
                            metadata={"title": f"🔍 Running tool {tool_call['name']}"},
                        )
                    )
            if chunk_messages.content:
                # Display a response made instead of or in addition to a tool call
                history.append(
                    gr.ChatMessage(role="assistant", content=chunk_messages.content)
                )
            yield history, [], []
        if node == "retrieve_emails":
            chunk_messages = chunk["messages"]
            # Loop over the tool call results
            count = 0
            retrieved_emails = []
            for message in chunk_messages:
                count += 1
                # Get the retrieved emails as a list
                email_list = message.content.replace(
                    "### Retrieved Emails:\n\n\n\n", ""
                ).split("--- --- --- --- Next Email --- --- --- ---\n\n")
                # Get the list of source files (e.g. R-help/2024-December.txt) for the retrieved emails
                month_list = [text.splitlines()[0] for text in email_list]
                # Format the months (e.g. 2024-December) into text
                month_text = (
                    ", ".join(month_list).replace("R-help/", "").replace(".txt", "")
                )
                # Get the number of retrieved emails
                n_emails = len(email_list)
                title = f"📥 Retrieved {n_emails} emails"
                if email_list[0] == "### No emails were retrieved":
                    title = "❌ Retrieved 0 emails"
                history.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=month_text,
                        metadata={"title": title},
                    )
                )
                # Format the retrieved emails with a Tool Call heading
                retrieved_emails.append(
                    message.content.replace(
                        "### Retrieved Emails:\n\n\n\n",
                        f"### ### ### ### Tool Call {count} ### ### ### ###\n\n",
                    )
                )
            # Combine all the Tool Call results
            retrieved_emails = "\n\n".join(retrieved_emails)
            yield history, retrieved_emails, []
        if node == "generate":
            chunk_messages = chunk["messages"]
            # Chat response without citations
            if chunk_messages.content:
                history.append(
                    gr.ChatMessage(role="assistant", content=chunk_messages.content)
                )
            # None is used for no change to the retrieved emails textbox
            yield history, None, []
        if node == "answer_with_citations":
            chunk_messages = chunk["messages"][0]
            # Parse the message for the answer and citations
            try:
                answer, citations = ast.literal_eval(chunk_messages.content)
            except Exception:
                # In case we got an answer without citations
                answer = chunk_messages.content
                citations = None
            history.append(gr.ChatMessage(role="assistant", content=answer))
            yield history, None, citations


def to_workflow(request: gr.Request, *args):
    """Wrapper function to call run_workflow with or without @spaces.GPU"""
    compute_mode = args[2]
    # Add session_hash to arguments
    new_args = args + (request.session_hash,)
    if compute_mode == "local":
        for value in run_workflow_local(*new_args):
            yield value
    if compute_mode == "remote":
        for value in run_workflow_remote(*new_args):
            yield value


# The @spaces.GPU decorator requests ZeroGPU hardware for local mode
# (assumed here from the to_workflow docstring and the spaces import)
@spaces.GPU
def run_workflow_local(*args):
    for value in run_workflow(*args):
        yield value


def run_workflow_remote(*args):
    for value in run_workflow(*args):
        yield value


# Custom CSS for bottom alignment
css = """
.row-container {
    display: flex;
    align-items: flex-end; /* Align components at the bottom */
    gap: 10px; /* Add spacing between components */
}
"""

with gr.Blocks(
    title="R-help-chat",
    # Noto Color Emoji gets a nice-looking Unicode Character "🇷" (U+1F1F7) on Chrome
    theme=gr.themes.Soft(
        font=[
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
            "Apple Color Emoji",
            "Segoe UI Emoji",
            "Segoe UI Symbol",
            "Noto Color Emoji",
        ]
    ),
    css=css,
) as demo:
| # ----------------- | |
| # Define components | |
| # ----------------- | |
| compute_mode = gr.Radio( | |
| choices=[ | |
| "local", | |
| "remote", | |
| ], | |
| value=("local" if torch.cuda.is_available() else "remote"), | |
| label="Compute Mode", | |
| info=(None if torch.cuda.is_available() else "NOTE: local mode requires GPU"), | |
| render=False, | |
| ) | |
| downloading = gr.Textbox( | |
| lines=1, | |
| label="Downloading Data, Please Wait", | |
| visible=False, | |
| render=False, | |
| ) | |
| extracting = gr.Textbox( | |
| lines=1, | |
| label="Extracting Data, Please Wait", | |
| visible=False, | |
| render=False, | |
| ) | |
| data_error = gr.Textbox( | |
| value="App is unavailable because data could not be loaded. Try reloading the page, then contact the maintainer if the problem persists.", | |
| lines=1, | |
| label="Error downloading or extracting data", | |
| visible=False, | |
| render=False, | |
| ) | |
| show_examples = gr.Checkbox( | |
| value=False, | |
| label="π‘ Example Questions", | |
| render=False, | |
| ) | |
| chatbot = gr.Chatbot( | |
| type="messages", | |
| show_label=False, | |
| avatar_images=( | |
| None, | |
| ( | |
| "images/cloud.png" | |
| if compute_mode.value == "remote" | |
| else "images/chip.png" | |
| ), | |
| ), | |
| show_copy_all_button=True, | |
| render=False, | |
| ) | |
    # Modified from gradio/chat_interface.py
    input = gr.Textbox(
        show_label=False,
        label="Message",
        placeholder="Type a message...",
        scale=7,
        autofocus=True,
        submit_btn=True,
        render=False,
    )
    emails_textbox = gr.Textbox(
        label="Retrieved Emails",
        info="Tip: Look for 'Tool Call' and 'Next Email' separators. Quoted lines (starting with '>') are removed before indexing.",
        lines=10,
        visible=False,
        render=False,
    )
    citations_textbox = gr.Textbox(
        label="Citations",
        lines=2,
        visible=False,
        render=False,
    )
| # ------------ | |
| # Set up state | |
| # ------------ | |
| def generate_thread_id(): | |
| """Generate a new thread ID""" | |
| thread_id = uuid.uuid4() | |
| print(f"Generated thread_id: {thread_id}") | |
| return thread_id | |
| # Define thread_id variable | |
| thread_id = gr.State(generate_thread_id()) | |
| # Define states for the output textboxes | |
| retrieved_emails = gr.State([]) | |
| citations_text = gr.State([]) | |
| # ------------------ | |
| # Make the interface | |
| # ------------------ | |
| def get_intro_text(): | |
| intro = f"""<!-- # π€ R-help-chat --> | |
| <!-- Get AI-powered answers about R programming backed by email retrieval. --> | |
| ## π·π€π¬ R-help-chat | |
| **Chat with the [R-help mailing list archives](https://stat.ethz.ch/pipermail/r-help/).** | |
| An LLM turns your question into a search query, including year ranges, and generates an answer from the retrieved emails. | |
| You can ask follow-up questions with the chat history as context. | |
| β‘οΈ To clear the history and start a new chat, press the ποΈ clear button. | |
| **_Answers may be incorrect._** | |
| """ | |
| return intro | |

    def get_status_text(compute_mode):
        if compute_mode == "remote":
            status_text = f"""
📍 Now in **remote** mode, using the OpenAI API<br>
⚠️ **_Privacy Notice_**: Data sharing with OpenAI is enabled<br>
✨ text-embedding-3-small and {openai_model}<br>
🔗 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
"""
        if compute_mode == "local":
            status_text = f"""
📍 Now in **local** mode, using ZeroGPU hardware<br>
⌛ Response time is around 2 minutes<br>
✨ [Nomic](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) embeddings and [{model_id}](https://huggingface.co/{model_id})<br>
🔗 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
"""
        return status_text

    def get_info_text():
        try:
            # Get source files for each email and start and end months from database
            sources = get_sources()
            start, end = get_start_end_months(sources)
        except Exception:
            # If the database isn't ready, use empty values
            sources = []
            start = None
            end = None
        info_text = f"""
**Database:** {len(sources)} emails from {start} to {end}.
**Features:** RAG, today's date, hybrid search (dense+sparse), query analysis,
multiple retrievals per turn (remote mode), answer with citations (remote mode), chat memory.
**Tech:** LangChain + Hugging Face + Gradio; ChromaDB and BM25S-based retrievers.<br>
"""
        return info_text

    with gr.Row():
        # Left column: Intro, Compute, Chat
        with gr.Column(scale=2):
            with gr.Row(elem_classes=["row-container"]):
                with gr.Column(scale=2):
                    intro = gr.Markdown(get_intro_text())
                with gr.Column(scale=1):
                    compute_mode.render()
            with gr.Group(visible=False) as chat_interface:
                chatbot.render()
                input.render()
            # Render textboxes for data loading progress
            downloading.render()
            extracting.render()
            data_error.render()
        # Right column: Info, Examples
        with gr.Column(scale=1):
            status = gr.Markdown(get_status_text(compute_mode.value))
            with gr.Accordion("ℹ️ More Info", open=False):
                info = gr.Markdown(get_info_text())
            with gr.Accordion("💡 Examples", open=True):
                # Add some helpful examples
                example_questions = [
                    # "What is today's date?",
                    "Summarize emails from the last two months",
                    "How to use plotmath?",
                    "When was has.HLC mentioned?",
                    "Who reported installation problems in 2023-2024?",
                ]
                gr.Examples(
                    examples=[[q] for q in example_questions],
                    inputs=[input],
                    label="Click an example to fill the message box",
                    elem_id="example-questions",
                )
                multi_tool_questions = [
                    "Differences between lapply and for loops",
                    "Compare usage of pipe operator between 2022 and 2024",
                ]
                gr.Examples(
                    examples=[[q] for q in multi_tool_questions],
                    inputs=[input],
                    label="Multiple retrievals (remote mode)",
                    elem_id="example-questions",
                )
                multi_turn_questions = [
                    "Lookup emails that reference bugs.r-project.org in 2025",
                    "Did those authors report bugs before 2025?",
                ]
                gr.Examples(
                    examples=[[q] for q in multi_turn_questions],
                    inputs=[input],
                    label="Asking follow-up questions",
                    elem_id="example-questions",
                )

    # Bottom row: retrieved emails and citations
    with gr.Row():
        with gr.Column(scale=2):
            emails_textbox.render()
        with gr.Column(scale=1):
            citations_textbox.render()
| # ------------- | |
| # App functions | |
| # ------------- | |
| def value(value): | |
| """Return updated value for a component""" | |
| return gr.update(value=value) | |
| def set_avatar(compute_mode): | |
| if compute_mode == "remote": | |
| image_file = "images/cloud.png" | |
| if compute_mode == "local": | |
| image_file = "images/chip.png" | |
| return gr.update( | |
| avatar_images=( | |
| None, | |
| image_file, | |
| ), | |
| ) | |
| def change_visibility(visible): | |
| """Return updated visibility state for a component""" | |
| return gr.update(visible=visible) | |
| def update_textbox(content, textbox): | |
| if content is None: | |
| # Keep the content of the textbox unchanged | |
| return textbox, change_visibility(True) | |
| elif content == []: | |
| # Blank out the textbox | |
| return "", change_visibility(False) | |
| else: | |
| # Display the content in the textbox | |
| return content, change_visibility(True) | |
| # -------------- | |
| # Event handlers | |
| # -------------- | |
| # Start a new thread when the user presses the clear (trash) button | |
| # https://github.com/gradio-app/gradio/issues/9722 | |
| chatbot.clear(generate_thread_id, outputs=[thread_id], api_name=False) | |
| def clear_component(component): | |
| """Return cleared component""" | |
| return component.clear() | |
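
    # Chained event: each .then() step runs after the previous one completes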
    compute_mode.change(
        # Change the app status text
        get_status_text,
        [compute_mode],
        [status],
        api_name=False,
    ).then(
        # Clear the chatbot history
        clear_component,
        [chatbot],
        [chatbot],
        api_name=False,
    ).then(
        # Change the chatbot avatar
        set_avatar,
        [compute_mode],
        [chatbot],
        api_name=False,
    ).then(
        # Start a new thread
        generate_thread_id,
        outputs=[thread_id],
        api_name=False,
    )

    input.submit(
        # Submit input to the chatbot
        to_workflow,
        [input, chatbot, compute_mode, thread_id],
        [chatbot, retrieved_emails, citations_text],
        api_name=False,
    )

    retrieved_emails.change(
        # Update the emails textbox
        update_textbox,
        [retrieved_emails, emails_textbox],
        [emails_textbox, emails_textbox],
        api_name=False,
    )

    citations_text.change(
        # Update the citations textbox
        update_textbox,
        [citations_text, citations_textbox],
        [citations_textbox, citations_textbox],
        api_name=False,
    )
| # ------------ | |
| # Data loading | |
| # ------------ | |
| def download(): | |
| """Download the application data""" | |
| # Code from https://thecodinginterface.com/blog/aws-s3-python-boto3 | |
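        # AWS credentials are read from environment variables (assumed to be set, e.g. as Spaces secrets)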
        def aws_session(region_name="us-east-1"):
            return boto3.session.Session(
                aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
                aws_secret_access_key=os.getenv("AWS_ACCESS_KEY_SECRET"),
                region_name=region_name,
            )

        def download_file_from_bucket(bucket_name, s3_key, dst_path):
            session = aws_session()
            s3_resource = session.resource("s3")
            bucket = s3_resource.Bucket(bucket_name)
            bucket.download_file(Key=s3_key, Filename=dst_path)

        if not os.path.isdir(db_dir):
            if not os.path.exists("db.zip"):
                download_file_from_bucket("r-help-chat", "db.zip", "db.zip")

        return None

    def extract():
        """Extract the db.zip file"""
        if not os.path.isdir(db_dir):
            file_path = "db.zip"
            extract_to_path = "./"
            try:
                with zipfile.ZipFile(file_path, "r") as zip_ref:
                    zip_ref.extractall(extract_to_path)
            except Exception:
                # If there were any errors, remove the zip file and db directory
                # to initiate a new download when the app is reloaded
                try:
                    os.remove(file_path)
                    print(f"{file_path} has been deleted.")
                except FileNotFoundError:
                    print(f"{file_path} does not exist.")
                except PermissionError:
                    print(f"Permission denied to delete {file_path}.")
                except Exception as e:
                    print(f"An error occurred: {e}")
                directory_path = "./db"
                try:
                    # Forcefully and recursively delete a directory, like rm -rf
                    shutil.rmtree(directory_path)
                    print(f"Successfully deleted: {directory_path}")
                except FileNotFoundError:
                    print(f"Directory not found: {directory_path}")
                except PermissionError:
                    print(f"Permission denied: {directory_path}")
                except Exception as e:
                    print(f"An error occurred: {e}")
        return None

    def is_data_present():
        """Check if the database directory is present"""
        return os.path.isdir(db_dir)

    def is_data_missing():
        """Check if the database directory is missing"""
        return not os.path.isdir(db_dir)
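
    # Boolean states used to control component visibility in the startup chain below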
    false = gr.State(False)
    need_data = gr.State()
    have_data = gr.State()

    # When the app is launched: check if data is present, download and extract it
    # if necessary, make the chat interface visible, update the database info, and
    # show the error textbox if data loading failed.
    # fmt: off
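    # ("fmt: off"/"fmt: on" keeps the Black formatter from reflowing this chain)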
    demo.load(
        is_data_missing, None, [need_data], api_name=False
    ).then(
        change_visibility, [need_data], [downloading], api_name=False
    ).then(
        download, None, [downloading], api_name=False
    ).then(
        change_visibility, [false], [downloading], api_name=False
    ).then(
        change_visibility, [need_data], [extracting], api_name=False
    ).then(
        extract, None, [extracting], api_name=False
    ).then(
        change_visibility, [false], [extracting], api_name=False
    ).then(
        is_data_present, None, [have_data], api_name=False
    ).then(
        change_visibility, [have_data], [chat_interface], api_name=False
    ).then(
        get_info_text, None, [info], api_name=False
    ).then(
        is_data_missing, None, [need_data], api_name=False
    ).then(
        change_visibility, [need_data], [data_error], api_name=False
    )
    # fmt: on

    # Clean up graph instances when the page is closed/refreshed
    demo.unload(cleanup_graph)


if __name__ == "__main__":
    # Set allowed_paths to serve the chatbot avatar images
    current_directory = os.getcwd()
    allowed_paths = [current_directory + "/images"]
    # Launch the Gradio app
    demo.launch(allowed_paths=allowed_paths)
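    # By default, Gradio serves the app locally at http://localhost:7860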