Eddie Pick committed on
By default now use spacy for retrieval and augmentation (vs embeddings)
- nlp_rag.py +144 -0
- requirements.txt +2 -1
- search_agent.py +40 -19
- spacy.ipynb +0 -0
- web_rag.py +2 -2
nlp_rag.py
ADDED
@@ -0,0 +1,144 @@
+import spacy
+from itertools import groupby
+from operator import itemgetter
+from langsmith import traceable
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import numpy as np
+
+def get_nlp_model():
+    if not spacy.util.is_package("en_core_web_md"):
+        print("Downloading en_core_web_md model...")
+        spacy.cli.download("en_core_web_md")
+        print("Model downloaded successfully!")
+    nlp = spacy.load("en_core_web_md")
+    return nlp
+
+
+def recursive_split_documents(contents, max_chunk_size=1000, overlap=100):
+    from langchain_core.documents.base import Document
+    from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+    documents = []
+    for content in contents:
+        try:
+            page_content = content['page_content']
+            if page_content:
+                metadata = {'title': content['title'], 'source': content['link']}
+                doc = Document(page_content=content['page_content'], metadata=metadata)
+                documents.append(doc)
+        except Exception as e:
+            print(f"Error processing content for {content['link']}: {e}")
+
+    # Initialize recursive text splitter
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=max_chunk_size, chunk_overlap=overlap)
+
+    # Split documents
+    split_documents = text_splitter.split_documents(documents)
+
+    # Convert split documents to the same format as recursive_split
+    chunks = []
+    for doc in split_documents:
+        chunk = {
+            'text': doc.page_content,
+            'metadata': {
+                'title': doc.metadata.get('title', ''),
+                'source': doc.metadata.get('source', '')
+            }
+        }
+        chunks.append(chunk)
+
+    return chunks
+
+
+def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
+    # Precompute query vector and its norm
+    query_vector = nlp(query).vector
+    query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon to avoid division by zero
+
+    # Check if chunks have precomputed vectors; if not, compute them
+    if 'vector' not in chunks[0]:
+        texts = [chunk['text'] for chunk in chunks]
+
+        # Process texts in batches using nlp.pipe()
+        batch_size = 1000  # Adjust based on available memory
+        with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
+            docs = nlp.pipe(texts, batch_size=batch_size)
+
+            # Add vectors to chunks
+            for chunk, doc in zip(chunks, docs):
+                chunk['vector'] = doc.vector
+
+    # Prepare chunk vectors and norms
+    chunk_vectors = np.array([chunk['vector'] for chunk in chunks])
+    chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8  # Add epsilon to avoid division by zero
+
+    # Compute similarities
+    similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)
+
+    # Filter and sort results
+    relevant_chunks = [
+        (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
+    ]
+    relevant_chunks.sort(key=lambda x: x[1], reverse=True)
+
+    return relevant_chunks[:top_n]
+
+
+# Perform semantic search using spaCy
+def semantic_search(query, chunks, nlp, similarity_threshold=0.5, top_n=10):
+    import numpy as np
+    from concurrent.futures import ThreadPoolExecutor
+
+    # Precompute query vector and its norm with epsilon to prevent division by zero
+    with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
+        query_vector = nlp(query).vector
+    query_norm = np.linalg.norm(query_vector) + 1e-8  # Add epsilon
+
+    # Prepare texts from chunks
+    texts = [chunk['text'] for chunk in chunks]
+
+    # Function to process each text and compute its vector
+    def compute_vector(text):
+        with nlp.disable_pipes(*[pipe for pipe in nlp.pipe_names if pipe != 'tok2vec']):
+            doc = nlp(text)
+            vector = doc.vector
+        return vector
+
+    # Process texts in parallel using ThreadPoolExecutor
+    with ThreadPoolExecutor() as executor:
+        chunk_vectors = list(executor.map(compute_vector, texts))
+
+    chunk_vectors = np.array(chunk_vectors)
+    chunk_norms = np.linalg.norm(chunk_vectors, axis=1) + 1e-8  # Add epsilon
+
+    # Compute similarities using vectorized operations
+    similarities = np.dot(chunk_vectors, query_vector) / (chunk_norms * query_norm)
+
+    # Filter and sort results
+    relevant_chunks = [
+        (chunk, sim) for chunk, sim in zip(chunks, similarities) if sim > similarity_threshold
+    ]
+    relevant_chunks.sort(key=lambda x: x[1], reverse=True)
+
+    return relevant_chunks[:top_n]
+
+
+@traceable(run_type="llm", name="nlp_rag")
+def query_rag(chat_llm, query, relevant_results):
+    import web_rag as wr
+
+    formatted_chunks = ""
+    for chunk, similarity in relevant_results:
+        formatted_chunk = f"""
+<source>
+<url>{chunk['metadata']['source']}</url>
+<title>{chunk['metadata']['title']}</title>
+<text>{chunk['text']}</text>
+</source>
+"""
+        formatted_chunks += formatted_chunk
+
+    prompt = wr.get_rag_prompt_template().format(query=query, context=formatted_chunks)
+
+    draft = chat_llm.invoke(prompt).content
+    return draft
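For orientation, a minimal sketch (not part of the commit) of how the new module is meant to be driven. The shape of contents — dicts with page_content, title and link keys — is inferred from recursive_split_documents above; the sample text and URL are placeholders, and chat stands for whatever LangChain chat model the caller already has (e.g. from md.get_model in search_agent.py).

import nlp_rag as nr

nlp = nr.get_nlp_model()  # loads en_core_web_md, downloading it on first use
contents = [{
    "page_content": "spaCy's medium English model ships static word vectors...",
    "title": "Example page",                  # placeholder
    "link": "https://example.com/spacy",      # placeholder
}]
chunks = nr.recursive_split_documents(contents)          # ~1000-char chunks, 100 overlap
results = nr.semantic_search("what are spaCy word vectors?", chunks, nlp, top_n=5)
for chunk, score in results:
    print(f"{score:.3f}", chunk["metadata"]["source"])
# draft = nr.query_rag(chat, "what are spaCy word vectors?", results)  # needs a chat model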
requirements.txt
CHANGED
@@ -30,4 +30,5 @@ tiktoken
 transformers >= 4.44.2
 rich >= 13.8.1
 trafilatura >= 1.12.2
-watchdog >= 2.1.5, < 5.0.0
+watchdog >= 2.1.5, < 5.0.0
+spacy >= 3.6.1, < 4.0.0
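The spacy pin only installs the library; the en_core_web_md vectors that semantic_search relies on are fetched separately (get_nlp_model does this lazily on first run). A quick sanity check, assuming the model has already been downloaded — the 300-dimension figure is a property of en_core_web_md, not of this commit:

import spacy

nlp = spacy.load("en_core_web_md")
doc = nlp("retrieval augmented generation")
print(doc.vector.shape)      # (300,) — static word vectors are present
print(doc.vector_norm > 0)   # True; the small model lacks static vectors, which is why get_nlp_model loads md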
search_agent.py
CHANGED
@@ -10,7 +10,7 @@ Usage:
     [--copywrite]
     [--max_pages=num]
     [--max_extracts=num]
-    [--…
+    [--use_browser]
     [--output=text]
     [--verbose]
     SEARCH_QUERY

@@ -23,10 +23,10 @@ Options:
     -d domain --domain=domain         Limit search to a specific domain
     -t temp --temperature=temp        Set the temperature of the LLM [default: 0.0]
     -m model --model=model            Use a specific model [default: openai/gpt-4o-mini]
-    -e model --embedding_model=model  Use …
+    -e model --embedding_model=model  Use an embedding model
     -n num --max_pages=num            Max number of pages to retrieve [default: 10]
     -x num --max_extracts=num         Max number of page extract to consider [default: 7]
-    -…
+    -b --use_browser                  Use browser to fetch content from the web [default: False]
     -o text --output=text             Output format (choices: text, markdown) [default: markdown]
     -v --verbose                      Print verbose output [default: False]

@@ -49,6 +49,7 @@ import web_rag as wr
 import web_crawler as wc
 import copywriter as cw
 import models as md
+import nlp_rag as nr

 console = Console()
 dotenv.load_dotenv()

@@ -91,32 +92,35 @@ def main(arguments):
     max_pages=int(arguments["--max_pages"])
     max_extract=int(arguments["--max_extracts"])
     output=arguments["--output"]
-    use_selenium=arguments["--…
+    use_selenium=arguments["--use_browser"]
     query = arguments["SEARCH_QUERY"]

     chat = md.get_model(model, temperature)
-    if embedding_model …
-    …
-    …
+    if embedding_model is None:
+        use_nlp = True
+        nlp = nr.get_nlp_model()
     else:
         embedding_model = md.get_embedding_model(embedding_model)
+        use_nlp = False

     if verbose:
         model_name = getattr(chat, 'model_name', None) or getattr(chat, 'model', None) or getattr(chat, 'model_id', None) or str(chat)
-        embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
-        console.log(f"Using model: {model_name}")
         console.log(f"Using embedding model: {embedding_model_name}")
+        if not use_nlp:
+            embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model)
+            console.log(f"Using model: {embedding_model_name}")
+

     with console.status(f"[bold green]Optimizing query for search: {query}"):
-        …
-        if len(…
-        …
-        console.log(f"Optimized search query: [bold blue]{…
+        optimized_search_query = wr.optimize_search_query(chat, query)
+        if len(optimized_search_query) < 3:
+            optimized_search_query = query
+        console.log(f"Optimized search query: [bold blue]{optimized_search_query}")

     with console.status(
-        f"[bold green]Searching sources using the optimized query: {…
+        f"[bold green]Searching sources using the optimized query: {optimized_search_query}"
     ):
-        sources = wc.get_sources(…
+        sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
     console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

     with console.status(

@@ -125,11 +129,28 @@ def main(arguments):
         contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium)
     console.log(f"Managed to extract content from {len(contents)} sources")

-    …
-    …

-    with console.status("[bold green]Writing content", spinner='dots8Bit'):
-        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract)
+    if use_nlp:
+        with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"):
+            chunks = nr.recursive_split_documents(contents)
+            #chunks = nr.chunk_contents(nlp, contents)
+            console.log(f"Split {len(contents)} sources into {len(chunks)} chunks")
+        with console.status(f"[bold green]Searching relevant chunks", spinner="growVertical"):
+            import time
+
+            start_time = time.time()
+            relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract)
+            end_time = time.time()
+            execution_time = end_time - start_time
+            console.log(f"Semantic search took {execution_time:.2f} seconds")
+            console.log(f"Found {len(relevant_results)} relevant chunks")
+        with console.status(f"[bold green]Writing content", spinner="growVertical"):
+            draft = nr.query_rag(chat, query, relevant_results)
+    else:
+        with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
+            vector_store = wc.vectorize(contents, embedding_model)
+        with console.status("[bold green]Writing content", spinner='dots8Bit'):
+            draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k = max_extract)

     console.rule(f"[bold green]Response")
     if output == "text":
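With these changes, running the agent without -e/--embedding_model takes the new spaCy path (split, semantic_search, query_rag), while passing an embedding model keeps the previous vector-store flow. Two hypothetical invocations, queries chosen purely for illustration:

python search_agent.py -v -n 10 -x 7 "What is retrieval augmented generation?"
python search_agent.py -b "a question whose sources need a browser-rendered fetch"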
spacy.ipynb
DELETED
The diff for this file is too large to render.
web_rag.py
CHANGED
@@ -74,13 +74,13 @@ def get_optimized_search_messages(query):
     chocolate chip cookies recipe from scratch**
     Example:
         Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
-        …
+        Marie Curie timeline**
     Example:
         Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
         geopolitics nato russia**
     Example:
         Question: Write an engaging LinkedIn post about Andrew Ng
-        …
+        Andrew Ng**
     Example:
         Question: Write a short article about the solar system in the style of Carl Sagan
         solar system**