Spaces: Running on Zero

jedick committed · 503a0b6 · Parent(s): 17ad0bb

Improve args for thinking mode
Files changed:
- app.py +18 -2
- graph.py +11 -22
- main.py +2 -4
- mods/tool_calling_llm.py +1 -3
- prompts.py +56 -48
app.py
CHANGED
@@ -87,7 +87,9 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
     )
     # Get the chat model and build the graph
     chat_model = GetChatModel(compute_mode)
-    graph_builder = BuildGraph(…
+    graph_builder = BuildGraph(
+        chat_model, compute_mode, search_type, think_query=True
+    )
     # Compile the graph with an in-memory checkpointer
     memory = MemorySaver()
     graph = graph_builder.compile(checkpointer=memory)
@@ -398,7 +400,7 @@ with gr.Blocks(
         end = None
     info_text = f"""
     **Database:** {len(sources)} emails from {start} to {end}.
-    **Features:** RAG, today's date, hybrid search (dense+sparse), thinking…
+    **Features:** RAG, today's date, hybrid search (dense+sparse), thinking output (local),
     multiple retrievals per turn (remote), answer with citations (remote), chat memory.
     **Tech:** LangChain + Hugging Face + Gradio; ChromaDB and BM25S-based retrievers.<br>
     """
@@ -537,6 +539,12 @@ with gr.Blocks(
         generate_thread_id,
         outputs=[thread_id],
         api_name=False,
+    ).then(
+        # Focus textbox by updating the textbox with the current value
+        lambda x: gr.update(value=x),
+        [input],
+        [input],
+        api_name=False,
     )

     input.submit(
@@ -563,6 +571,14 @@ with gr.Blocks(
         api_name=False,
     )

+    chatbot.clear(
+        # Focus textbox when the chatbot is cleared
+        lambda x: gr.update(value=x),
+        [input],
+        [input],
+        api_name=False,
+    )
+
     # ------------
     # Data loading
     # ------------
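Both Gradio additions use the same idiom: chain a follow-up event that writes the textbox's current value back to itself, which (per the comments in the commit) is enough to return focus to the input. A minimal, self-contained sketch of the pattern; the button and textbox names here are illustrative, not from app.py:

import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox(label="Ask a question")
    new_chat = gr.Button("New chat")
    # The first event does its real work (a stand-in for generate_thread_id);
    # .then() chains a second event that re-emits the textbox value unchanged,
    # which per the commit's comment refocuses the textbox.
    new_chat.click(lambda: None, api_name=False).then(
        lambda x: gr.update(value=x),  # no-op update: same value in, same value out
        [box],
        [box],
        api_name=False,
    )

demo.launch()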
graph.py
CHANGED
@@ -71,23 +71,17 @@ def normalize_messages(messages):
     return messages


-def ToolifyHF(chat_model, system_message, system_message_suffix="", think=False):
+def ToolifyHF(chat_model, system_message, system_message_suffix=""):
     """
     Get a Hugging Face model ready for bind_tools().
     """

-    ## Add /no_think flag to turn off thinking mode (SmolLM3 and Qwen)
-    # if not think:
-    #     system_message = "/no_think\n" + system_message
-
     # Combine system prompt and tools template
     tool_system_prompt_template = system_message + generic_tools_template

     class HuggingFaceWithTools(ToolCallingLLM, ChatHuggingFace):
-
-        …
-        # Allows adding attributes dynamically
-        extra = "allow"
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)

     chat_model = HuggingFaceWithTools(
         llm=chat_model.llm,
@@ -96,9 +90,6 @@ def ToolifyHF(chat_model, system_message, system_message_suffix="", think=False)
         system_message_suffix=system_message_suffix,
     )

-    # The "model" attribute is needed for ToolCallingLLM to print the response if it can't be parsed
-    chat_model.model = chat_model.model_id + "_for_tools"
-
     return chat_model


@@ -107,8 +98,7 @@ def BuildGraph(
     compute_mode,
     search_type,
     top_k=6,
-    …
-    think_generate=False,
+    think_query=False,
 ):
     """
     Build conversational RAG graph for email retrieval and answering with citations.
@@ -118,8 +108,7 @@
     compute_mode: remote or local (for retriever)
     search_type: dense, sparse, or hybrid (for retriever)
     top_k: number of documents to retrieve
-    …
-    think_generate: Whether to use thinking mode for generation
+    think_query: Whether to use thinking mode for query

     Based on:
     https://python.langchain.com/docs/how_to/qa_sources
@@ -206,7 +195,7 @@
     if is_local:
         # For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
         query_model = ToolifyHF(
-            chat_model, query_prompt(…
+            chat_model, query_prompt(chat_model, think=think_query), ""
         ).bind_tools([retrieve_emails])
         # Don't use answer_with_citations tool because responses with it are sometimes unparseable
         generate_model = chat_model
@@ -227,7 +216,7 @@
         messages = normalize_messages(messages)
         # print_message_summaries(messages, "--- query: after normalization ---")
     else:
-        messages = [SystemMessage(query_prompt(…
+        messages = [SystemMessage(query_prompt(chat_model))] + state["messages"]
     response = query_model.invoke(messages)

     return {"messages": response}
@@ -239,12 +228,12 @@
         # print_message_summaries(messages, "--- generate: before normalization ---")
         messages = normalize_messages(messages)
         # Add the system message here because we're not using tools
-        messages = [
-            SystemMessage(generate_prompt(with_tools=False, think=False))
-        ] + messages
+        messages = [SystemMessage(generate_prompt(chat_model))] + messages
         # print_message_summaries(messages, "--- generate: after normalization ---")
     else:
-        messages = […
+        messages = [
+            SystemMessage(generate_prompt(chat_model, with_tools=True))
+        ] + state["messages"]
     response = generate_model.invoke(messages)

     return {"messages": response}
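One subtlety behind the class-body change: ToolCallingLLM and ChatHuggingFace are Pydantic models, which by default reject assignment of undeclared attributes. The old wrapper opted into extra = "allow" precisely so it could attach the ad-hoc "model" attribute after construction; with that attribute removed, a plain pass-through __init__ is all that remains. An illustrative Pydantic v2 sketch; the class names are made up, not from the repo:

from pydantic import BaseModel, ConfigDict

class Strict(BaseModel):
    llm: str

class Open(BaseModel):
    model_config = ConfigDict(extra="allow")  # v2 counterpart of extra = "allow"
    llm: str

open_model = Open(llm="x")
open_model.model = "x_for_tools"      # OK: extra attributes may be set dynamically

strict_model = Strict(llm="x")
# strict_model.model = "x_for_tools"  # would raise ValueError: no field "model"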
main.py
CHANGED
@@ -200,10 +200,8 @@ def RunChain(
     # Get chat model (LLM)
     chat_model = GetChatModel(compute_mode)

-    #
-    system_prompt = generate_prompt()
-    if hasattr(chat_model, "model_id") and not think:
-        system_prompt = f"/no_think\n{system_prompt}"
+    # Get prompt with /no_think for SmolLM3/Qwen
+    system_prompt = generate_prompt(chat_model)

     # Create a prompt template
     system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
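The net effect in RunChain is that the /no_think decision now lives in one place instead of being re-implemented by every caller. Roughly (comments paraphrase the two revisions):

# Before: each caller checked the model itself
# system_prompt = generate_prompt()
# if hasattr(chat_model, "model_id") and not think:
#     system_prompt = f"/no_think\n{system_prompt}"

# After: check_prompt() inside prompts.py handles it for all callers
system_prompt = generate_prompt(chat_model)  # think defaults to False

Note that if RunChain still accepts a think flag, passing it through as generate_prompt(chat_model, think=think) would preserve the old opt-in behavior; as written, the new call always uses the default.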
mods/tool_calling_llm.py
CHANGED
@@ -299,9 +299,7 @@ class ToolCallingLLM(BaseChatModel, ABC):
         )
         if called_tool is None:
             # Issue a warning and return the generated content 20250727 jmd
-            warnings.warn(
-                f"Tool {called_tool} called from {self.model} output not in functions list"
-            )
+            warnings.warn(f"Called tool ({called_tool}) not in functions list")
             return AIMessage(content=response_message.content)

         # Get tool arguments from output
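Worth noting: this branch only runs when called_tool is None, so the old f-string always rendered as "Tool None called from …"; the shorter message also drops self.model, which ToolifyHF (above) no longer sets now that the "_for_tools" attribute is gone.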
prompts.py
CHANGED
@@ -3,65 +3,73 @@ from util import get_sources, get_start_end_months
 import re


-def …
-    """
-    …
-
-
-    """
+def check_prompt(prompt, chat_model, think):
+    """Check for unassigned variables and add /no_think if needed"""
+    # A sanity check that we don't have unassigned variables
+    # (this causes KeyError in parsing by ToolCallingLLM)
+    matches = re.findall(r"\{.*?\}", " ".join(prompt))
+    if matches:
+        raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
+    # Check if we should add /no_think to turn off thinking mode
+    if hasattr(chat_model, "model_id"):
+        model_id = chat_model.model_id
+        if ("SmolLM" in model_id or "Qwen" in model_id) and not think:
+            prompt = "/no_think\n" + prompt
+    return prompt
+
+
+def query_prompt(chat_model, think=False):
+    """Return system prompt for query step"""

     # Get start and end months from database
     start, end = get_start_end_months(get_sources())

-    …
-        f"Today Date: {date.today()}."
-        "You are a helpful …
-        "Write a search query to retrieve emails relevant to the user's question."
-        "Do not answer the user's question and do not ask the user for more information."
-        # gpt-4o-mini thinks last two months aren't available with this: "Emails from {start} to {end} are available for retrieval."
-        f"The emails available for retrieval are from {start} to {end}."
-        "For questions about differences or comparison between X and Y, retrieve emails about X and Y."
-        "For general summaries, use retrieve_emails(search_query='R')."
-        "For specific questions, use retrieve_emails(search_query=<specific topic>)."
-        "For questions about years, use retrieve_emails(search_query=, start_year=, end_year=) (this month is this year)."
-        "For questions about months, use 3-letter abbreviations (Jan…
-        "Even if retrieved emails are available, you should retrieve more emails to answer the most recent question."  # Qwen
-        # "You must perform the search yourself. Do not tell the user how to retrieve emails."  # Qwen
-        "Do not use your memory or knowledge to answer the user's question. Only retrieve emails based on the user's question."  # Qwen
-        # "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list."
-    )
-    …
-    …
-    …
-    if matches:
-        raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
-    return query_prompt
+    prompt = (
+        f"Today Date: {date.today()}. "
+        "You are a helpful assistant designed to get information about R programming from the R-help mailing list archives. "
+        "Write a search query to retrieve emails relevant to the user's question. "
+        "Do not answer the user's question and do not ask the user for more information. "
+        # gpt-4o-mini thinks last two months aren't available with this: "Emails from {start} to {end} are available for retrieval. "
+        f"The emails available for retrieval are from {start} to {end}. "
+        "For questions about differences or comparison between X and Y, retrieve emails about X and Y. "
+        "For general summaries, use retrieve_emails(search_query='R'). "
+        "For specific questions, use retrieve_emails(search_query=<specific topic>). "
+        "For questions about years, use retrieve_emails(search_query=, start_year=, end_year=) (this month is this year). "
+        "For questions about months, use 3-letter abbreviations (Jan...Dec) for the 'month' argument. "
+        "Even if retrieved emails are available, you should retrieve more emails to answer the most recent question. "  # Qwen
+        # "You must perform the search yourself. Do not tell the user how to retrieve emails. "  # Qwen
+        "Do not use your memory or knowledge to answer the user's question. Only retrieve emails based on the user's question. "  # Qwen
+        # "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list. "
+    )
+    prompt = check_prompt(prompt, chat_model, think)
+
+    return prompt


-def generate_prompt(…
+def generate_prompt(chat_model, think=False, with_tools=False):
     """Return system prompt for generate step"""
-    …
-        f"Today Date: {date.today()}."
-        "You are a helpful chatbot designed to answer questions about R programming based on the R-help mailing list archives."
-        "Summarize the retrieved emails to answer the user's question or query."
-        "If any of the retrieved emails are irrelevant (e.g. wrong dates), then do not use them."
-        "Tell the user if there are no retrieved emails or if you are unable to answer the question based on the information in the emails."
-        "Do not give an answer based on your own knowledge or memory, and do not include examples that aren't based on the retrieved emails."
-        "Example: For a question about using lm(), take examples of lm() from the retrieved emails to answer the user's question."
-        # "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages."
-        "Summarize the content of the emails rather than copying the headers."  # Qwen
-        "You must include inline citations (email senders and dates) in each part of your response."
-        "Only answer general questions about R if the answer is in the retrieved emails."
-        "Respond with 300 words maximum and 30 lines of code maximum and include any relevant URLs from the retrieved emails."
-    )
+    prompt = (
+        f"Today Date: {date.today()}. "
+        "You are a helpful chatbot designed to answer questions about R programming based on the R-help mailing list archives. "
+        "Summarize the retrieved emails to answer the user's question or query. "
+        "If any of the retrieved emails are irrelevant (e.g. wrong dates), then do not use them. "
+        "Tell the user if there are no retrieved emails or if you are unable to answer the question based on the information in the emails. "
+        "Do not give an answer based on your own knowledge or memory, and do not include examples that aren't based on the retrieved emails. "
+        "Example: For a question about using lm(), take examples of lm() from the retrieved emails to answer the user's question. "
+        # "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages. "
+        "Summarize the content of the emails rather than copying the headers. "  # Qwen
+        "You must include inline citations (email senders and dates) in each part of your response. "
+        "Only answer general questions about R if the answer is in the retrieved emails. "
+        "Respond with 300 words maximum and 30 lines of code maximum and include any relevant URLs from the retrieved emails. "
+    )
     if with_tools:
-        …
-        …
-        …
-        …
-        …
-        …
-    return …
+        prompt = (
+            f"{prompt}"
+            "Use answer_with_citations to provide the complete answer and all citations used. "
+        )
+    prompt = check_prompt(prompt, chat_model, think)
+
+    return prompt


 # Prompt template for SmolLM3 with tools