Spaces: Running on Zero

jedick committed · Commit 3575a77 · 1 Parent(s): ace4242

Enable thinking for answer

Files changed:
- app.py +9 -8
- graph.py +17 -13
- main.py +8 -8
- mods/tool_calling_llm.py +11 -3
- prompts.py +2 -2
app.py CHANGED

@@ -88,7 +88,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
  # Get the chat model and build the graph
  chat_model = GetChatModel(compute_mode)
  graph_builder = BuildGraph(
- chat_model, compute_mode, search_type,
+ chat_model, compute_mode, search_type, think_answer=True
  )
  # Compile the graph with an in-memory checkpointer
  memory = MemorySaver()

@@ -184,7 +184,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
  retrieved_emails = "\n\n".join(retrieved_emails)
  yield history, retrieved_emails, []

- if node == "generate":
+ if node == "answer":
  # Append messages (thinking and non-thinking) to history
  chunk_messages = chunk["messages"]
  history = append_content(chunk_messages, history, thinking_about="answer")

@@ -383,8 +383,9 @@ with gr.Blocks(
  status_text = f"""
  π Now in **local** mode, using ZeroGPU hardware<br>
  β Response time is about one minute<br>
-
-
+ 🧠 Thinking is enabled for the answer<br>
+   π Add **/think** to enable thinking for the query</br>
+   🚫 Add **/no_think** to disable all thinking</br>
  ✨ [nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) and [{model_id.split("/")[-1]}](https://huggingface.co/{model_id})<br>
  π See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
  """

@@ -412,15 +413,15 @@ with gr.Blocks(
  """Get example questions based on compute mode"""
  questions = [
  # "What is today's date?",
- "Summarize emails from the last two months",
- "
+ "Summarize emails from the last two months /no_think",
+ "Show me code examples using plotmath",
  "When was has.HLC mentioned?",
  "Who reported installation problems in 2023-2024?",
  ]

  if compute_mode == "remote":
- # Remove "/
- questions = [q.replace(" /
+ # Remove "/no_think" from questions in remote mode
+ questions = [q.replace(" /no_think", "") for q in questions]

  # cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
  return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
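The `if node == "answer"` check filters streamed graph updates by node name. Below is a minimal sketch of that pattern, not the app's actual loop: `graph` stands for a compiled LangGraph (a runnable wiring sketch follows the graph.py diff below), `config` for a checkpointer config, and the list-based `history` replaces the app's append_content helper.

```python
# Sketch only: pick out messages produced by the "answer" node from a stream.
# `graph` is a compiled LangGraph and `config` a placeholder thread config.
config = {"configurable": {"thread_id": "demo"}}
history = []
for update in graph.stream(
    {"messages": [("user", "Show me code examples using plotmath")]},
    config,
    stream_mode="updates",
):
    # With stream_mode="updates", each chunk is keyed by the node that just ran
    for node, chunk in update.items():
        if node == "answer":
            messages = chunk["messages"]
            # A node may return a single message or a list of messages
            if not isinstance(messages, list):
                messages = [messages]
            for message in messages:
                history.append({"role": "assistant", "content": message.content})
```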
graph.py CHANGED

@@ -9,7 +9,7 @@ import os

  # Local modules
  from retriever import BuildRetriever
- from prompts import query_prompt,
+ from prompts import query_prompt, answer_prompt, generic_tools_template
  from mods.tool_calling_llm import ToolCallingLLM

  # For tracing (disabled)

@@ -94,6 +94,7 @@ def BuildGraph(
  search_type,
  top_k=6,
  think_query=False,
+ think_answer=False,
  ):
  """
  Build conversational RAG graph for email retrieval and answering with citations.

@@ -103,7 +104,8 @@ def BuildGraph(
  compute_mode: remote or local (for retriever)
  search_type: dense, sparse, or hybrid (for retriever)
  top_k: number of documents to retrieve
- think_query: Whether to use thinking mode for query
+ think_query: Whether to use thinking mode for the query
+ think_answer: Whether to use thinking mode for the answer

  Based on:
  https://python.langchain.com/docs/how_to/qa_sources

@@ -193,11 +195,11 @@ def BuildGraph(
  chat_model, query_prompt(chat_model, think=think_query)
  ).bind_tools([retrieve_emails])
  # Don't use answer_with_citations tool because responses with are sometimes unparseable
-
+ answer_model = chat_model
  else:
  # For remote model (OpenAI API)
  query_model = chat_model.bind_tools([retrieve_emails])
-
+ answer_model = chat_model.bind_tools([answer_with_citations])

  # Initialize the graph object
  graph = StateGraph(MessagesState)

@@ -216,27 +218,29 @@ def BuildGraph(

  return {"messages": response}

- def generate(state: MessagesState):
+ def answer(state: MessagesState):
  """Generates an answer with the chat model"""
  if is_local:
  messages = state["messages"]
- # print_message_summaries(messages, "---
+ # print_message_summaries(messages, "--- answer: before normalization ---")
  messages = normalize_messages(messages)
  # Add the system message here because we're not using tools
- messages = [
-
+ messages = [
+ SystemMessage(answer_prompt(chat_model, think=think_answer))
+ ] + messages
+ # print_message_summaries(messages, "--- answer: after normalization ---")
  else:
  messages = [
- SystemMessage(
+ SystemMessage(answer_prompt(chat_model, with_tools=True))
  ] + state["messages"]
- response =
+ response = answer_model.invoke(messages)

  return {"messages": response}

  # Define model and tool nodes
  graph.add_node("query", query)
- graph.add_node("generate", generate)
  graph.add_node("retrieve_emails", ToolNode([retrieve_emails]))
+ graph.add_node("answer", answer)
  graph.add_node("answer_with_citations", ToolNode([answer_with_citations]))

  # Route the user's input to the query model

@@ -249,13 +253,13 @@ def BuildGraph(
  {END: END, "tools": "retrieve_emails"},
  )
  graph.add_conditional_edges(
- "generate",
+ "answer",
  tools_condition,
  {END: END, "tools": "answer_with_citations"},
  )

  # Add edge from the retrieval tool to the generating model
- graph.add_edge("retrieve_emails", "generate")
+ graph.add_edge("retrieve_emails", "answer")

  # Done!
  return graph
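After this change the graph routes query → retrieve_emails → answer, with tools_condition deciding whether a tool call was requested. The sketch below reproduces that routing with a stub tool and stub node functions so it runs without any chat model; the real nodes invoke the tool-bound query_model and answer_model from the diff, and the answer_with_citations branch is omitted here.

```python
# Runnable sketch of the routing pattern, with stubs in place of chat models.
from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition


@tool
def retrieve_emails(query: str) -> str:
    """Retrieve emails matching the query (stubbed)."""
    return "[R-help] Example email about plotmath"


def query(state: MessagesState):
    # Stub: always request retrieval; the real node invokes a tool-bound model
    call = {"name": "retrieve_emails", "args": {"query": "plotmath"}, "id": "call_1"}
    return {"messages": AIMessage(content="", tool_calls=[call])}


def answer(state: MessagesState):
    # Stub: the real node invokes answer_model with answer_prompt(...)
    return {"messages": AIMessage(content="Answer based on retrieved emails.")}


graph = StateGraph(MessagesState)
graph.add_node("query", query)
graph.add_node("retrieve_emails", ToolNode([retrieve_emails]))
graph.add_node("answer", answer)
graph.add_edge(START, "query")
# If the query node requested a tool call, run retrieval; otherwise finish
graph.add_conditional_edges("query", tools_condition, {END: END, "tools": "retrieve_emails"})
# After retrieval, generate the answer, then finish (no citation tool in this sketch)
graph.add_edge("retrieve_emails", "answer")
graph.add_edge("answer", END)

app = graph.compile()
result = app.invoke({"messages": [("user", "Show me code examples using plotmath")]})
print(result["messages"][-1].content)
```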
main.py CHANGED

@@ -23,7 +23,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
  from index import ProcessFile
  from retriever import BuildRetriever, db_dir
  from graph import BuildGraph
- from prompts import
+ from prompts import answer_prompt

  # -----------
  # R-help-chat

@@ -201,7 +201,7 @@ def RunChain(
  chat_model = GetChatModel(compute_mode)

  # Get prompt with /no_think for SmolLM3/Qwen
- system_prompt =
+ system_prompt = answer_prompt(chat_model)

  # Create a prompt template
  system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])

@@ -236,8 +236,8 @@ def RunGraph(
  compute_mode: str = "remote",
  search_type: str = "hybrid",
  top_k: int = 6,
-
-
+ think_query=False,
+ think_answer=False,
  thread_id=None,
  ):
  """Run graph for conversational RAG app

@@ -247,8 +247,8 @@ def RunGraph(
  compute_mode: Compute mode for embedding and chat models (remote or local)
  search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
  top_k: Number of documents to retrieve
-
-
+ think_query: Whether to use thinking mode for the query
+ think_answer: Whether to use thinking mode for the answer
  thread_id: Thread ID for memory (optional)

  Example:

@@ -263,8 +263,8 @@ def RunGraph(
  compute_mode,
  search_type,
  top_k,
-
-
+ think_query,
+ think_answer,
  )

  # Compile the graph with an in-memory checkpointer
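A hypothetical call using the new flags follows. The first positional argument (the user's question) and its parameter name are not shown in this diff, so treat the call shape as an assumption rather than the documented API.

```python
# Assumed usage; the question is passed positionally, as the docstring's
# Example section (not shown in this diff) suggests
from main import RunGraph

RunGraph(
    "When was has.HLC mentioned?",
    compute_mode="local",
    search_type="hybrid",
    top_k=6,
    think_query=False,  # no thinking for the tool-calling query step
    think_answer=True,  # enable thinking for the final answer
)
```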
mods/tool_calling_llm.py CHANGED

@@ -183,10 +183,18 @@ class ToolCallingLLM(BaseChatModel, ABC):

  # Parse output for JSON (support multiple objects separated by commas)
  try:
+ # Works for one or more JSON objects not enclosed in "[]"
  parsed_json_results = json.loads(f"[{post_think}]")
- except
-
-
+ except:
+ try:
+ # Works for one or more JSON objects already enclosed in "[]"
+ parsed_json_results = json.loads(f"{post_think}")
+ except json.JSONDecodeError:
+ # Return entire response if JSON wasn't parsed (or is missing)
+ return AIMessage(content=response_message.content)
+
+ # print("parsed_json_results")
+ # print(parsed_json_results)

  tool_calls = []
  for parsed_json_result in parsed_json_results:
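The new fallback tries bracket-wrapping first, then the raw string, before giving up on JSON entirely. Here is a standalone sketch of the same idea; the function name and the sample tool-call schema are illustrative, and json.JSONDecodeError is caught instead of the bare except used above.

```python
import json


def parse_tool_json(post_think: str):
    """Parse tool-call JSON from model output, mirroring the fallback above."""
    try:
        # Works for one or more bare JSON objects separated by commas
        return json.loads(f"[{post_think}]")
    except json.JSONDecodeError:
        try:
            # Works if the text is already valid JSON (e.g. a bracketed array)
            return json.loads(post_think)
        except json.JSONDecodeError:
            # Signal the caller to treat the response as plain text
            return None


print(parse_tool_json('{"name": "retrieve_emails", "arguments": {"query": "plotmath"}}'))
print(parse_tool_json("No JSON in this response"))
```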
prompts.py CHANGED

@@ -46,8 +46,8 @@ def query_prompt(chat_model, think=False):
  return prompt


- def
- """Return system prompt for
+ def answer_prompt(chat_model, think=False, with_tools=False):
+ """Return system prompt for answer step"""
  prompt = (
  f"Today Date: {date.today()}. "
  "You are a helpful chatbot designed to answer questions about R programming based on the R-help mailing list archives. "