Files changed:
- README.md +2 -0
- copywriter.py +37 -5
- requirements.txt +5 -1
- search_agent.py +10 -8
- search_agent_ui.py +80 -18
- web_crawler.py +2 -1
- web_rag.py +76 -37
README.md
CHANGED
@@ -10,6 +10,8 @@ pinned: false
 license: apache-2.0
 ---
 
+⚠️ **This project is a demonstration / proof-of-concept and is not intended for use in production environments. It is provided as-is, without warranty or guarantee of any kind. The code and any accompanying materials are for educational, testing, or evaluation purposes only.**⚠️
+
 # Simple Search Agent
 
 This Python project provides a search agent that can perform web searches, optimize search queries, fetch and process web content, and generate responses using a language model and the retrieved information.
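For orientation, the pipeline the README describes maps onto the modules changed below. A minimal, hypothetical end-to-end sketch using the function names visible in this diff (signatures, defaults, and the provider choice are assumptions, not taken from the full sources):

```python
# Hypothetical sketch of the search-agent pipeline; names come from the diff below,
# exact signatures and defaults may differ in the actual repository.
import web_rag as wr
import web_crawler as wc

chat, embedding_model = wr.get_models("openai", model=None, temperature=0.0)

query = "How do I bake chocolate chip cookies from scratch?"
search_query = wr.optimize_search_query(chat, query, callbacks=[])   # condense to keywords
sources = wc.get_sources(search_query, max_pages=10)                 # web search results
contents = wc.get_links_contents(sources)                            # fetch and extract page text
vector_store = wc.vectorize(contents, embedding_model)               # embed chunks for retrieval
answer = wr.query_rag(chat, query, search_query, vector_store, top_k=5)
print(answer)
```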
copywriter.py
CHANGED
@@ -7,7 +7,6 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate
 
 
-
 def get_comments_prompt(query, draft):
     system_message = SystemMessage(
         content="""
@@ -35,14 +34,11 @@ def get_comments_prompt(query, draft):
     )
     return [system_message, human_message]
 
-
 def generate_comments(chat_llm, query, draft, callbacks=[]):
     messages = get_comments_prompt(query, draft)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
     return response.content
 
-
-
 def get_final_text_prompt(query, draft, comments):
     system_message = SystemMessage(
         content="""
@@ -74,4 +70,40 @@ def get_final_text_prompt(query, draft, comments):
 def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
     messages = get_final_text_prompt(query, draft, comments)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
-    return response.content
+    return response.content
+
+
+def get_compare_texts_prompts(query, draft_text, final_text):
+    system_message = SystemMessage(
+        content="""
+        I want you to act as a writing quality evaluator.
+        I will provide you with the original user request and four texts.
+        Your task is to carefully analyze, compare the two texts across the following dimensions and grade each text 0 to 10:
+        1. Grammar and spelling - Which text has fewer grammatical errors and spelling mistakes?
+        2. Clarity and coherence - Which text is easier to understand and has a more logical flow of ideas? Evaluate how well each text conveys its main points.
+        3. Tone and style - Which text has a more appropriate and engaging tone and writing style for its intended purpose and audience?
+        4. Sticking to the request - Which text is more successful responding to the original user request. Consider the request, the style, the length, etc.
+        5. Overall effectiveness - Considering the above factors, which text is more successful overall at communicating its message and achieving its goals?
+
+        After comparing the texts on these criteria, clearly state which text you think is better and summarize the main reasons why.
+        Provide specific examples from each text to support your evaluation.
+        """
+    )
+    human_message = HumanMessage(
+        content=f"""
+        Original query: {query}
+        ------------------------
+        Text 1: {draft_text}
+        ------------------------
+        Text 2: {final_text}
+        ------------------------
+        Summary:
+        """
+    )
+    return [system_message, human_message]
+
+
+def compare_text(chat_llm, query, draft, final, callbacks=[]):
+    messages = get_compare_texts_prompts(query, draft_text=draft, final_text=final)
+    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+    return response.content
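A small, hypothetical sketch of how the copywriter helpers above chain together once a draft exists. The ChatOpenAI model and the example strings are illustrations only; any LangChain chat model returned by web_rag.get_models would do:

```python
# Illustrative wiring of the copywriter helpers; chat model and strings are examples only.
from langchain_openai import ChatOpenAI
import copywriter as cw

chat_llm = ChatOpenAI(temperature=0.0)

query = "Write an engaging LinkedIn post about Andrew Ng"
draft = "Andrew Ng has shaped how millions of people learn AI..."   # draft from the RAG step

comments = cw.generate_comments(chat_llm, query, draft)              # reviewer feedback
final = cw.generate_final_text(chat_llm, query, draft, comments)     # rewritten text
verdict = cw.compare_text(chat_llm, query, draft=draft, final=final) # graded comparison
print(verdict)
```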
requirements.txt
CHANGED
@@ -1,5 +1,7 @@
+anthropic
 boto3
 bs4
+chromedriver-py
 cohere
 docopt
 faiss-cpu
@@ -7,7 +9,7 @@ google-api-python-client
 pdfplumber
 python-dotenv
 langchain
-langchain-
+langchain-aws
 langchain-fireworks
 langchain_core
 langchain_community
@@ -18,6 +20,8 @@ langsmith
 schema
 streamlit
 selenium
+tiktoken
+transformers
 rich
 trafilatura
 watchdog
search_agent.py
CHANGED
@@ -8,6 +8,7 @@ Usage:
         [--temperature=temp]
         [--copywrite]
         [--max_pages=num]
+        [--max_extracts=num]
         [--output=text]
         SEARCH_QUERY
     search_agent.py --version
@@ -21,6 +22,7 @@ Options:
     -p provider --provider=provider     Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
     -m model --model=model              Use a specific model
     -n num --max_pages=num              Max number of pages to retrieve [default: 10]
+    -e num --max_extracts=num           Max number of page extract to consider [default: 5]
     -o text --output=text               Output format (choices: text, markdown) [default: markdown]
 
 """
@@ -63,8 +65,6 @@ def get_selenium_driver():
     driver = webdriver.Chrome(options=chrome_options)
     return driver
 
-
-
 callbacks = []
 if os.getenv("LANGCHAIN_API_KEY"):
     callbacks.append(
@@ -90,14 +90,16 @@ if __name__ == '__main__':
     temperature = float(arguments["--temperature"])
     domain=arguments["--domain"]
     max_pages=arguments["--max_pages"]
+    max_extract=int(arguments["--max_extracts"])
     output=arguments["--output"]
     query = arguments["SEARCH_QUERY"]
 
     chat, embedding_model = wr.get_models(provider, model, temperature)
-    #console.log(f"Using {chat.model_name} on {provider}")
 
     with console.status(f"[bold green]Optimizing query for search: {query}"):
         optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
+    if len(optimize_search_query) < 3:
+        optimize_search_query = query
     console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
 
     with console.status(
@@ -112,11 +114,11 @@ if __name__ == '__main__':
         contents = wc.get_links_contents(sources, get_selenium_driver)
     console.log(f"Managed to extract content from {len(contents)} sources")
 
-    with console.status(f"[bold green]
+    with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
         vector_store = wc.vectorize(contents, embedding_model)
 
-    with console.status("[bold green]
-        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k =
+    with console.status("[bold green]Writing content", spinner='dots8Bit'):
+        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract, callbacks=callbacks)
 
     console.rule(f"[bold green]Response from {provider}")
     if output == "text":
@@ -129,7 +131,7 @@ if __name__ == '__main__':
     with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
         comments = cw.generate_comments(chat, query, draft, callbacks=callbacks)
 
-    console.rule(
+    console.rule("[bold green]Response from reviewer")
     if output == "text":
         console.print(comments)
     else:
@@ -139,7 +141,7 @@ if __name__ == '__main__':
     with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
         final_text = cw.generate_final_text(chat, query, draft, comments, callbacks=callbacks)
 
-    console.rule(
+    console.rule("[bold green]Final text")
     if output == "text":
         console.print(final_text)
     else:
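The new --max_extracts flag follows the same docopt pattern as --max_pages. A self-contained sketch (abridged usage string and an invented argv) of how docopt turns these flags into the arguments dict read above:

```python
# Minimal docopt sketch for the flags added in this diff; the usage string is abridged.
"""
Usage:
    search_agent.py [--max_pages=num] [--max_extracts=num] SEARCH_QUERY

Options:
    -n num --max_pages=num      Max number of pages to retrieve [default: 10]
    -e num --max_extracts=num   Max number of page extracts to consider [default: 5]
"""
from docopt import docopt

arguments = docopt(__doc__, argv=["--max_extracts=3", "what is retrieval augmented generation?"])
max_pages = arguments["--max_pages"]             # "10", taken from the default
max_extract = int(arguments["--max_extracts"])   # 3, taken from the command line
query = arguments["SEARCH_QUERY"]
```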
search_agent_ui.py
CHANGED
@@ -10,6 +10,7 @@ from langsmith.client import Client
 
 import web_rag as wr
 import web_crawler as wc
+import copywriter as cw
 
 dotenv.load_dotenv()
 
@@ -18,7 +19,6 @@ ls_tracer = LangChainTracer(
     client=Client()
 )
 
-
 class StreamHandler(BaseCallbackHandler):
     """Stream handler that appends tokens to container."""
     def __init__(self, container, initial_text=""):
@@ -28,11 +28,36 @@ class StreamHandler(BaseCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs):
         self.text += token
         self.container.markdown(self.text)
+
 
+def create_links_markdown(sources_list):
+    """
+    Create a markdown string for each source in the provided JSON.
+
+    Args:
+        sources_list (list): A list of dictionaries representing the sources.
+            Each dictionary should have 'title', 'link', and 'snippet' keys.
+
+    Returns:
+        str: A markdown string with a bullet point for each source,
+            including the title linked to the URL and the snippet.
+    """
+    markdown_list = []
+    for source in sources_list:
+        title = source['title']
+        link = source['link']
+        snippet = source['snippet']
+        markdown = f"- [{title}]({link})\n {snippet}"
+        markdown_list.append(markdown)
+    return "\n".join(markdown_list)
+
+st.set_page_config(layout="wide")
 st.title("🔍 Simple Search Agent 💬")
 
 if "providers" not in st.session_state:
     providers = []
+    if os.getenv("FIREWORKS_API_KEY"):
+        providers.append("fireworks")
     if os.getenv("COHERE_API_KEY"):
         providers.append("cohere")
     if os.getenv("OPENAI_API_KEY"):
@@ -41,22 +66,34 @@ if "providers" not in st.session_state:
         providers.append("groq")
     if os.getenv("OLLAMA_API_KEY"):
         providers.append("ollama")
-    if os.getenv("FIREWORKS_API_KEY"):
-        providers.append("fireworks")
     if os.getenv("CREDENTIALS_PROFILE_NAME"):
         providers.append("bedrock")
     st.session_state["providers"] = providers
 
-with st.sidebar:
-    st.
-
-
-
-
+with st.sidebar.expander("Options", expanded=False):
+    model_provider = st.selectbox("Model provider 🧠", st.session_state["providers"])
+    temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
+    max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrive from the internet")
+    top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
+    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a write, then comments and then rewrite")
+
+with st.sidebar.expander("Links", expanded=False):
+    links_md = st.markdown("")
+
+if reviewer_mode:
+    with st.sidebar.expander("Answer review", expanded=False):
+        st.caption("Draft")
+        draft_md = st.markdown("")
+        st.divider()
+        st.caption("Comments")
+        comments_md = st.markdown("")
+        st.divider()
+        st.caption("Comparaison")
+        comparaison_md = st.markdown("")
 
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
-
+
 for message in st.session_state.messages:
     st.chat_message(message["role"]).write(message["content"])
     if message["role"] == "assistant" and 'message_id' in message:
@@ -80,6 +117,7 @@ if prompt := st.chat_input("Enter you instructions..." ):
         st.write(f"I should search the web for: {optimize_search_query}")
 
         sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
+        links_md.markdown(create_links_markdown(sources))
 
         st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
         contents = wc.get_links_contents(sources)
@@ -87,18 +125,42 @@ if prompt := st.chat_input("Enter you instructions..." ):
         st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
         vector_store = wc.vectorize(contents, embedding_model=embedding_model)
         st.write(f"I collected {vector_store.index.ntotal} chunk of data and I can now answer")
+
+
+    if reviewer_mode:
+        st.write("Creating a draft")
+        draft_prompt = wr.build_rag_prompt(
+            chat, prompt, optimize_search_query,
+            vector_store, top_k=top_k_documents, callbacks=[ls_tracer])
+        draft = chat.invoke(draft_prompt, stream=False, config={ "callbacks": [ls_tracer]})
+        draft_md.markdown(draft.content)
+        st.write("Sending draft for review")
+        comments = cw.generate_comments(chat, prompt, draft, callbacks=[ls_tracer])
+        comments_md.markdown(comments)
+        st.write("Reviewing comments and generating final answer")
+        rag_prompt = cw.get_final_text_prompt(prompt, draft, comments)
+    else:
+        rag_prompt = wr.build_rag_prompt(
+            chat, prompt, optimize_search_query, vector_store,
+            top_k=top_k_documents, callbacks=[ls_tracer]
+        )
 
-    rag_prompt = wr.build_rag_prompt(prompt, optimize_search_query, vector_store, top_k=5, callbacks=[ls_tracer])
     with st.chat_message("assistant"):
         st_cb = StreamHandler(st.empty())
         result = chat.invoke(rag_prompt, stream=True, config={ "callbacks": [st_cb, ls_tracer]})
        response = result.content.strip()
        message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
        st.session_state.messages.append({"role": "assistant", "content": response})
-
-
-
-
-
-
-
+
+    if st.session_state.messages[-1]["role"] == "assistant":
+        st.download_button(
+            label="Download",
+            data=st.session_state.messages[-1]["content"],
+            file_name=f"{message_id}.txt",
+            mime="text/plain"
+        )
+
+    if reviewer_mode:
+        compare_prompt = cw.get_compare_texts_prompts(prompt, draft_text=draft, final_text=response)
+        result = chat.invoke(compare_prompt, stream=False, config={ "callbacks": [ls_tracer]})
+        comparaison_md.markdown(result.content)
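The new create_links_markdown is a pure helper, so it is easy to sanity-check on its own. A small sketch with a made-up sources list shaped like the search results the UI passes around (assumes the function above is in scope; values are invented):

```python
# Toy sources list (invented values) with the 'title'/'link'/'snippet' keys the helper reads.
sources = [
    {"title": "Marie Curie - Wikipedia",
     "link": "https://example.org/marie-curie",
     "snippet": "Polish and naturalised-French physicist and chemist."},
    {"title": "Nobel Prize facts",
     "link": "https://example.org/nobel-facts",
     "snippet": "Curie won Nobel Prizes in both Physics and Chemistry."},
]

# Prints one markdown bullet per source: "- [title](link)" with the snippet on the next line.
print(create_links_markdown(sources))
```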
web_crawler.py
CHANGED
@@ -35,12 +35,13 @@ def get_sources(query, max_pages=10, domain=None):
         json_response = response.json()
 
         if 'web' not in json_response or 'results' not in json_response['web']:
+            print(response.text)
             raise Exception('Invalid API response format')
 
         final_results = [{
             'title': result['title'],
             'link': result['url'],
-            'snippet': result['description'],
+            'snippet': extract(result['description'], output_format='txt', include_tables=False, include_images=False, include_formatting=True),
             'favicon': result.get('profile', {}).get('img', '')
         } for result in json_response['web']['results']]
 
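For reference, the fields this comprehension reads from the search-API JSON are unchanged; only the snippet is now routed through trafilatura's extract to strip markup. A toy response carrying just those fields (values invented; the trafilatura call is elided here):

```python
# Toy response with only the fields get_sources reads; all values are invented.
json_response = {
    "web": {
        "results": [
            {
                "title": "Example result",
                "url": "https://example.org/article",
                "description": "<strong>Example</strong> description with markup.",
                "profile": {"img": "https://example.org/favicon.ico"},
            }
        ]
    }
}

final_results = [{
    "title": r["title"],
    "link": r["url"],
    "snippet": r["description"],   # the real code cleans this with trafilatura.extract(...)
    "favicon": r.get("profile", {}).get("img", ""),
} for r in json_response["web"]["results"]]

print(final_results)
```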
web_rag.py
CHANGED
@@ -28,13 +28,14 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
+from langchain_aws import ChatBedrock
 from langchain_cohere.chat_models import ChatCohere
 from langchain_cohere.embeddings import CohereEmbeddings
 from langchain_fireworks.chat_models import ChatFireworks
-from langchain_groq import ChatGroq
+#from langchain_groq import ChatGroq
+from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
-from langchain_community.chat_models.bedrock import BedrockChat
 from langchain_community.embeddings.bedrock import BedrockEmbeddings
 from langchain_community.chat_models.ollama import ChatOllama
 
@@ -44,15 +45,15 @@ def get_models(provider, model=None, temperature=0.0):
             credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
             if model is None:
                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
-            chat_llm = 
+            chat_llm = ChatBedrock(
                 credentials_profile_name=credentials_profile_name,
                 model_id=model,
-                model_kwargs={"temperature": temperature,
+                model_kwargs={"temperature": temperature, "max_tokens":4096 },
+            )
+            embedding_model = BedrockEmbeddings(
+                model_id='cohere.embed-multilingual-v3',
+                credentials_profile_name=credentials_profile_name
             )
-            #embedding_model = BedrockEmbeddings(
-            #    model_id='cohere.embed-multilingual-v3',
-            #    credentials_profile_name=credentials_profile_name
-            #)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'openai':
             if model is None:
@@ -73,14 +74,17 @@ def get_models(provider, model=None, temperature=0.0):
             if model is None:
                 model = 'command-r-plus'
             chat_llm = ChatCohere(model=model, temperature=temperature)
-            embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+            #embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'fireworks':
             if model is None:
-                model = 'accounts/fireworks/models/
-
+                #model = 'accounts/fireworks/models/dbrx-instruct'
+                model = 'accounts/fireworks/models/llama-v3-70b-instruct'
+            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=8192)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")
+
     return chat_llm, embedding_model
 
 
@@ -96,12 +100,13 @@ def get_optimized_search_messages(query):
     """
     system_message = SystemMessage(
         content="""
-            I want you to act as a prompt optimizer for web search.
+            I want you to act as a prompt optimizer for web search.
+            I will provide you with a chat prompt, and your goal is to optimize it into a search string that will yield the most relevant and useful information from a search engine like Google.
             To optimize the prompt:
-            Identify the key information being requested
-            Arrange the keywords into a concise search string
-            Keep it short, around 1 to 5 words total
-            Put the most important keywords first
+            - Identify the key information being requested
+            - Arrange the keywords into a concise search string
+            - Keep it short, around 1 to 5 words total
+            - Put the most important keywords first
 
             Some tips and things to be sure to remove:
             - Remove any conversational or instructional phrases
@@ -110,44 +115,44 @@ def get_optimized_search_messages(query):
            - Remove style instructions (exmaple: "in the style of", engaging, short, long)
            - Remove lenght instruction (example: essay, article, letter, etc)
 
-
+            You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
 
            Example:
            Question: How do I bake chocolate chip cookies from scratch?
-
+            chocolate chip cookies recipe from scratch**
            Example:
            Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
-
+            Marie Curie timeline**
            Example:
            Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
-
+            geopolitics nato russia**
            Example:
            Question: Write an engaging LinkedIn post about Andrew Ng
-
+            Andrew Ng**
            Example:
            Question: Write a short article about the solar system in the style of Carl Sagan
-
+            solar system**
            Example:
            Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
-
+            Kubernetes decision**
            Example:
            Question: Biography of Napoleon. Include a table with the major events.
-
+            napoleon biography events**
            Example:
            Question: Write a short article on the history of the United States. Include a table with the major events.
-
+            united states history events**
            Example:
            Question: Write a short article about the solar system in the style of donald trump
-
+            solar system**
            Exmaple:
            Question: Write a short linkedin about how the "freakeconomics" book previsions didn't pan out
-
+            freakeconomics book predictions failed**
            """
    )
    human_message = HumanMessage(
        content=f"""
            Question: {query}
-
+
            """
    )
    return [system_message, human_message]
@@ -230,15 +235,49 @@ def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = [
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
     return response.content
 
-
-
-
-
-
-
-
+def get_context_size(chat_llm):
+    if isinstance(chat_llm, ChatOpenAI):
+        if chat_llm.model_name.startswith("gpt-4"):
+            return 128000
+        else:
+            return 16385
+    if isinstance(chat_llm, ChatFireworks):
+        return 8192
+    if isinstance(chat_llm, ChatGroq):
+        return 37862
+    if isinstance(chat_llm, ChatOllama):
+        return 8192
+    if isinstance(chat_llm, ChatCohere):
+        return 128000
+    if isinstance(chat_llm, ChatBedrock):
+        if chat_llm.model_id.startswith("anthropic.claude-3"):
+            return 200000
+        if chat_llm.model_id.startswith("anthropic.claude"):
+            return 100000
+        if chat_llm.model_id.startswith("mistral"):
+            if chat_llm.model_id.startswith("mistral.mixtral-8x7b"):
+                return 4096
+            else:
+                return 8192
+    return 4096
+
+
+def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
+    done = False
+    while not done:
+        unique_docs = vectorstore.similarity_search(
+            search_query, k=top_k, callbacks=callbacks, verbose=True)
+        context = format_docs(unique_docs)
+        prompt = get_rag_prompt_template().format(query=question, context=context)
+        nbr_tokens = chat_llm.get_num_tokens(prompt)
+        if top_k <= 1 or nbr_tokens <= get_context_size(chat_llm) - 768:
+            done = True
+        else:
+            top_k = int(top_k * 0.75)
+
+    return prompt
 
 def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
-    prompt = build_rag_prompt(question, search_query, vectorstore, top_k=
+    prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-    return response.content
+    return response.content
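The heart of the new build_rag_prompt is a shrink-to-fit loop: retrieve the top_k most similar chunks, and if the assembled prompt does not fit in the model's context window (less a 768-token reserve for the answer), cut top_k by 25% and retry. A self-contained sketch of just that loop, with the retriever and token counter stubbed out (both stubs are invented for illustration):

```python
# Self-contained sketch of the shrink-to-fit loop used by build_rag_prompt.
# fake_similarity_search and fake_num_tokens stand in for the FAISS retriever
# and chat_llm.get_num_tokens(); they are invented for this example.

def fake_similarity_search(query, k):
    # Stand-in for vectorstore.similarity_search(query, k=k).
    return [f"chunk {i} about {query}" for i in range(k)]

def fake_num_tokens(text):
    # Rough 4-characters-per-token estimate instead of chat_llm.get_num_tokens().
    return len(text) // 4

def build_prompt(question, context_window, top_k=10):
    done = False
    while not done:
        chunks = fake_similarity_search(question, k=top_k)
        context = "\n".join(chunks)
        prompt = f"Context:\n{context}\n\nQuestion: {question}"
        if top_k <= 1 or fake_num_tokens(prompt) <= context_window - 768:
            done = True                   # prompt fits, or we cannot shrink further
        else:
            top_k = int(top_k * 0.75)     # drop a quarter of the extracts and retry
    return prompt

print(build_prompt("what is RAG?", context_window=1024))
```

The same reserve (768 tokens) and shrink factor (0.75) appear in the diff above; the stubbed numbers here are only placeholders.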