Spaces:
Paused
Paused
"""
search_agent.py

Usage:
    search_agent.py
        [--domain=domain]
        [--provider=provider]
        [--model=model]
        [--embedding_model=model]
        [--temperature=temp]
        [--copywrite]
        [--max_pages=num]
        [--max_extracts=num]
        [--use_browser]
        [--output=text]
        [--verbose]
        SEARCH_QUERY
    search_agent.py --version

Options:
    -h --help                           Show this screen.
    --version                           Show version.
    -c --copywrite                      First produce a draft, review it and rewrite for a final text
    -d domain --domain=domain           Limit search to a specific domain
    -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
    -m model --model=model              Use a specific model [default: hf:Qwen/Qwen2.5-72B-Instruct]
    -e model --embedding_model=model    Use an embedding model
    -n num --max_pages=num              Max number of pages to retrieve [default: 10]
    -x num --max_extracts=num           Max number of page extract to consider [default: 7]
    -b --use_browser                    Use browser to fetch content from the web [default: False]
    -o text --output=text               Output format (choices: text, markdown) [default: markdown]
    -v --verbose                        Print verbose output [default: False]
"""
import os

from docopt import docopt
import dotenv
from langchain.callbacks import LangChainTracer
from langsmith import Client, traceable
from rich.console import Console
from rich.markdown import Markdown

# Project-local helpers: RAG over web pages, crawling, copywriting passes,
# model construction, and spaCy-based NLP retrieval.
import web_rag as wr
import web_crawler as wc
import copywriter as cw
import models as md
import nlp_rag as nr

# Initialize console for rich text output
console = Console()

# Load environment variables from a .env file (API keys, tracing config)
dotenv.load_dotenv()
def get_selenium_driver():
    """Create and return a headless Chrome Selenium WebDriver.

    Selenium is imported lazily so the dependency is only required when
    browser-based fetching is actually requested.

    Returns:
        A configured ``webdriver.Chrome`` instance, or ``None`` if the
        driver could not be created.
    """
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.common.exceptions import WebDriverException

    # Flags for a lightweight, container-friendly headless session
    # (no GPU, no sandbox, images disabled to speed up page loads).
    flags = (
        "--headless",
        "--disable-extensions",
        "--disable-gpu",
        "--no-sandbox",
        "--disable-dev-shm-usage",
        "--remote-debugging-port=9222",
        "--blink-settings=imagesEnabled=false",
        "--window-size=1920,1080",
    )
    options = Options()
    for flag in flags:
        options.add_argument(flag)

    try:
        return webdriver.Chrome(options=options)
    except WebDriverException as e:
        # Best-effort: report the failure and let the caller handle None.
        print(f"Error creating Selenium WebDriver: {e}")
        return None
# Callback handlers for LangChain runs; populated with a LangSmith tracer
# when tracing credentials are configured in the environment.
# NOTE(review): `callbacks` is not referenced anywhere in this file's visible
# code — confirm it is consumed elsewhere (or passed implicitly) before removal.
callbacks = []
# Add LangChainTracer to callbacks if API key is set
if os.getenv("LANGCHAIN_API_KEY"):
    callbacks.append(
        LangChainTracer(client=Client())
    )
| def main(arguments): | |
| """Main function to execute the search agent logic.""" | |
| verbose = arguments["--verbose"] | |
| copywrite_mode = arguments["--copywrite"] | |
| model = arguments["--model"] | |
| embedding_model = arguments["--embedding_model"] | |
| temperature = float(arguments["--temperature"]) | |
| domain = arguments["--domain"] | |
| max_pages = int(arguments["--max_pages"]) | |
| max_extract = int(arguments["--max_extracts"]) | |
| output = arguments["--output"] | |
| use_selenium = arguments["--use_browser"] | |
| query = arguments["SEARCH_QUERY"] | |
| # Get the language model based on the provided model name and temperature | |
| chat = md.get_model(model, temperature) | |
| # If no embedding model is provided, use spacy for semantic search | |
| if embedding_model is None: | |
| use_nlp = True | |
| nlp = nr.get_nlp_model() | |
| else: | |
| use_nlp = False | |
| embedding_model = md.get_embedding_model(embedding_model) | |
| # Log model details if verbose mode is enabled | |
| if verbose: | |
| model_name = getattr(chat, 'model_name', None) or getattr(chat, 'model', None) or getattr(chat, 'model_id', None) or str(chat) | |
| console.log(f"Using model: {model_name}") | |
| if not use_nlp: | |
| embedding_model_name = getattr(embedding_model, 'model_name', None) or getattr(embedding_model, 'model', None) or getattr(embedding_model, 'model_id', None) or str(embedding_model) | |
| console.log(f"Using embedding model: {embedding_model_name}") | |
| # Optimize the search query | |
| with console.status(f"[bold green]Optimizing query for search: {query}"): | |
| optimized_search_query = wr.optimize_search_query(chat, query) | |
| if len(optimized_search_query) < 3: | |
| optimized_search_query = query | |
| console.log(f"Optimized search query: [bold blue]{optimized_search_query}") | |
| # Retrieve sources using the optimized query | |
| with console.status( | |
| f"[bold green]Searching sources using the optimized query: {optimized_search_query}" | |
| ): | |
| sources = wc.get_sources(optimized_search_query, max_pages=max_pages, domain=domain) | |
| console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}") | |
| # Fetch content from the retrieved sources | |
| with console.status( | |
| f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical" | |
| ): | |
| contents = wc.get_links_contents(sources, get_selenium_driver, use_selenium=use_selenium) | |
| console.log(f"Managed to extract content from {len(contents)} sources") | |
| # Process content using spaCy or embedding model | |
| if use_nlp: | |
| with console.status(f"[bold green]Splitting {len(contents)} sources for content", spinner="growVertical"): | |
| chunks = nr.recursive_split_documents(contents) | |
| console.log(f"Split {len(contents)} sources into {len(chunks)} chunks") | |
| with console.status(f"[bold green]Searching relevant chunks", spinner="growVertical"): | |
| relevant_results = nr.semantic_search(optimized_search_query, chunks, nlp, top_n=max_extract) | |
| console.log(f"Found {len(relevant_results)} relevant chunks") | |
| with console.status(f"[bold green]Writing content", spinner="growVertical"): | |
| draft = nr.query_rag(chat, query, relevant_results) | |
| else: | |
| with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"): | |
| vector_store = wc.vectorize(contents, embedding_model) | |
| with console.status("[bold green]Writing content", spinner='dots8Bit'): | |
| draft = wr.query_rag(chat, query, optimized_search_query, vector_store, top_k=max_extract) | |
| # If copywrite mode is enabled, generate comments and final text | |
| if(copywrite_mode): | |
| with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"): | |
| comments = cw.generate_comments(chat, query, draft) | |
| with console.status("[bold green]Writing the final text", spinner="dots8Bit"): | |
| final_text = cw.generate_final_text(chat, query, draft, comments) | |
| else: | |
| final_text = draft | |
| # Output the answer | |
| console.rule(f"[bold green]Response") | |
| if output == "text": | |
| console.print(final_text) | |
| else: | |
| console.print(Markdown(final_text)) | |
| console.rule("[bold green]") | |
| return final_text | |
if __name__ == '__main__':
    # Parse command-line arguments (docopt reads the module docstring)
    # and execute the main function.
    arguments = docopt(__doc__, version='Search Agent 0.1')
    main(arguments)