"""AI Deep Research API: a FastAPI service that searches the web and streams research reports as Server-Sent Events."""

import os
import asyncio
import json
import logging
import random
import re
import time
from collections import defaultdict
from typing import Any, AsyncGenerator, Optional, Tuple, List, Dict
from urllib.parse import quote_plus, urlparse, unquote

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from dotenv import load_dotenv
import aiohttp
from bs4 import BeautifulSoup
from fake_useragent import UserAgent

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
load_dotenv()

LLM_API_KEY = os.getenv("LLM_API_KEY")
if not LLM_API_KEY:
    raise RuntimeError("LLM_API_KEY must be set in a .env file.")
else:
    logging.info("LLM API Key loaded successfully.")

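# The key is loaded from the environment via python-dotenv. A minimal .env file would
# look like the following (the value is an illustrative placeholder, not a real key):
#
#   LLM_API_KEY=sk-your-key-here
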
LLM_API_URL = "https://api.typegpt.net/v1/chat/completions"
LLM_MODEL = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
MAX_SOURCES_TO_PROCESS = 20     # maximum number of sources scraped per research run
MAX_CONCURRENT_REQUESTS = 2     # concurrent source fetches
SEARCH_TIMEOUT = 300            # seconds

TOTAL_TIMEOUT = 1800            # seconds for the whole research pipeline
REQUEST_DELAY = 3.0             # seconds between outbound requests
RETRY_ATTEMPTS = 5
RETRY_DELAY = 5.0               # seconds between retries
USER_AGENT_ROTATION = True

CONTEXT_WINDOW_SIZE = 10_000_000
MAX_CONTEXT_SIZE = 2_000_000    # characters of crawled context passed to the LLM per section

RESPECT_ROBOTS_TXT = False

try:
    ua = UserAgent()
except Exception:
    # Fallback when fake_useragent cannot load its data: serve a few static, modern UA strings.
    class SimpleUA:
        def random(self):
            return random.choice([
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36",
                "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36",
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:129.0) Gecko/20100101 Firefox/129.0"
            ])

    ua = SimpleUA()

LLM_HEADERS = {
    "Authorization": f"Bearer {LLM_API_KEY}",
    "Content-Type": "application/json",
    "Accept": "application/json"
}

class DeepResearchRequest(BaseModel):
    query: str
    search_time: int = 300


class SearchRequest(BaseModel):
    query: str
    search_time: int = 60
    max_results: int = 20

app = FastAPI(
    title="AI Deep Research API",
    description="Provides comprehensive research reports from real web searches (search phase capped at 5 minutes), streamed as Server-Sent Events.",
    version="3.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)

def extract_json_from_llm_response(text: str) -> Optional[list]:
    """Extract JSON array from LLM response text."""
    match = re.search(r'\[.*\]', text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            return None
    return None

async def get_real_user_agent() -> str:
    """Get a realistic user agent string."""
    try:
        if isinstance(ua, UserAgent):
            return ua.random
        return ua.random()
    except Exception:
        return "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"

def clean_url(url: str) -> str:
    """Clean up and normalize URLs."""
    if not url:
        return ""

    if url.startswith('//duckduckgo.com/l/') or url.startswith('/l/?'):
        if url.startswith('//'):
            url = f"https:{url}"
        elif url.startswith('/'):
            url = f"https://duckduckgo.com{url}"
        try:
            parsed = urlparse(url)
            query_params = parsed.query
            if 'uddg=' in query_params:
                match = re.search(r'uddg=([^&]+)', query_params)
                if match:
                    return unquote(match.group(1))
        except Exception:
            pass

    if url.startswith('//'):
        url = 'https:' + url
    elif not url.startswith(('http://', 'https://')):
        url = 'https://' + url

    return url

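# Illustrative example of the redirect handling above (hypothetical URL): a DuckDuckGo
# redirect such as "//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage" is
# resolved by clean_url() to "https://example.com/page", while a bare
# "example.com/page" is normalized to "https://example.com/page".
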
async def check_robots_txt(url: str) -> bool:
    """Check if scraping is allowed by robots.txt."""
    if not RESPECT_ROBOTS_TXT:
        return True
    try:
        domain_match = re.search(r'https?://([^/]+)', url)
        if not domain_match:
            return False

        domain = domain_match.group(1)
        robots_url = f"https://{domain}/robots.txt"

        async with aiohttp.ClientSession() as session:
            headers = {'User-Agent': await get_real_user_agent()}
            async with session.get(robots_url, headers=headers, timeout=5) as response:
                if response.status == 200:
                    robots = await response.text()
                    if "Disallow: /" in robots:
                        return False
                    path = re.sub(r'https?://[^/]+', '', url)
                    if any(f"Disallow: {p}" in robots for p in [path, path.rstrip('/') + '/']):
                        return False
        return True
    except Exception as e:
        logging.warning(f"Could not check robots.txt for {url}: {e}")

    return True

async def fetch_search_results(query: str, max_results: int = 5) -> List[dict]:
    """Perform a real search using DuckDuckGo (Lite/HTML) with multi-endpoint fallback to reduce 202 issues."""
    ua_hdr = await get_real_user_agent()
    common_headers = {
        "User-Agent": ua_hdr,
        "Accept-Language": "en-US,en;q=0.9",
        "DNT": "1",
        "Cache-Control": "no-cache",
        "Pragma": "no-cache",
        "Referer": "https://duckduckgo.com/",
    }

    endpoints = [
        {"name": "lite-get", "method": "GET", "url": lambda q: f"https://lite.duckduckgo.com/lite/?q={quote_plus(q)}&kl=us-en&bing_market=us-en"},
        {"name": "lite-post", "method": "POST", "url": lambda q: f"https://lite.duckduckgo.com/lite/?q={quote_plus(q)}&kl=us-en&bing_market=us-en"},
        {"name": "html-mirror", "method": "GET", "url": lambda q: f"https://html.duckduckgo.com/html/?q={quote_plus(q)}"},
        {"name": "html", "method": "GET", "url": lambda q: f"https://duckduckgo.com/html/?q={quote_plus(q)}"},
    ]

    def parse_results_from_html(html: str) -> List[dict]:
        soup = BeautifulSoup(html, 'html.parser')
        results: List[dict] = []

        # Standard DuckDuckGo HTML result containers.
        candidates = soup.select('.result__body')
        if not candidates:
            candidates = soup.select('.result')

        for result in candidates:
            try:
                title_elem = result.select_one('.result__title .result__a') or result.select_one('.result__a')
                if not title_elem:
                    title_elem = result.find('a')
                if not title_elem:
                    continue
                link = title_elem.get('href')
                if not link:
                    continue
                snippet_elem = result.select_one('.result__snippet') or result.find('p')
                clean_link = clean_url(link)
                if not clean_link or clean_link.startswith('javascript:'):
                    continue
                snippet = snippet_elem.get_text(strip=True) if snippet_elem else ""
                title_text = title_elem.get_text(strip=True)
                results.append({'title': title_text, 'link': clean_link, 'snippet': snippet})
            except Exception as e:
                logging.warning(f"Error parsing search result: {e}")
                continue

        # Fallback: DuckDuckGo Lite redirect links.
        if not results:
            lite_links = soup.select('a[href*="/l/?uddg="]')
            for a in lite_links:
                try:
                    href = a.get('href')
                    title_text = a.get_text(strip=True)
                    if not href or not title_text:
                        continue
                    clean_link = clean_url(href)
                    if not clean_link or clean_link.startswith('javascript:'):
                        continue
                    results.append({'title': title_text, 'link': clean_link, 'snippet': ''})
                    if len(results) >= max_results:
                        break
                except Exception:
                    continue

        # Last resort: any anchor that looks like an outbound result link.
        if not results:
            anchors = soup.find_all('a', href=True)
            for a in anchors:
                text = a.get_text(strip=True)
                href = a['href']
                if not text or not href:
                    continue
                if '/l/?' in href or href.startswith('http') or href.startswith('//'):
                    clean_link = clean_url(href)
                    if clean_link and not clean_link.startswith('javascript:'):
                        results.append({'title': text, 'link': clean_link, 'snippet': ''})
                        if len(results) >= max_results * 2:
                            break

        return results[:max_results]

    for attempt in range(RETRY_ATTEMPTS):
        try:
            async with aiohttp.ClientSession() as session:
                for ep in endpoints:
                    url = ep['url'](query)
                    headers = {**common_headers, "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}
                    try:
                        if ep['method'] == 'GET':
                            resp = await session.get(url, headers=headers, timeout=12)
                        else:
                            resp = await session.post(url, headers=headers, timeout=12)
                        async with resp as response:
                            if response.status == 200:
                                html = await response.text()
                                results = parse_results_from_html(html)
                                if results:
                                    logging.info(f"Found {len(results)} real search results for '{query}' via {ep['name']}")
                                    return results

                                logging.warning(f"No results parsed from {ep['name']} for '{query}', trying next endpoint...")
                                continue
                            elif response.status == 202:
                                logging.warning(f"Search attempt {attempt + 1} got 202 at {ep['name']} for '{query}', trying next endpoint")
                                continue
                            else:
                                logging.warning(f"Search failed with status {response.status} at {ep['name']} for '{query}'")
                                continue
                    except asyncio.TimeoutError:
                        logging.warning(f"Timeout contacting {ep['name']} for '{query}'")
                        continue
                    except Exception as e:
                        logging.warning(f"Error contacting {ep['name']} for '{query}': {e}")
                        continue

        except Exception as e:
            logging.error(f"Search attempt {attempt + 1} failed for '{query}': {e}")

        if attempt < RETRY_ATTEMPTS - 1:
            await asyncio.sleep(RETRY_DELAY)

    logging.error(f"All {RETRY_ATTEMPTS} search attempts failed across endpoints for '{query}'")
    return []

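# For reference, each entry returned by fetch_search_results() is a plain dict of the
# form {"title": ..., "link": ..., "snippet": ...}; the snippet may be empty when the
# result came from one of the fallback link parsers above.
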
async def process_web_source(session: aiohttp.ClientSession, source: dict, timeout: int = 15) -> Tuple[str, dict]:
    """Process a real web source with improved content extraction and error handling."""
    headers = {'User-Agent': await get_real_user_agent()}
    source_info = source.copy()
    source_info['link'] = clean_url(source['link'])

    if not source_info['link'] or not source_info['link'].startswith(('http://', 'https://')):
        logging.warning(f"Invalid URL: {source_info['link']}")
        return source.get('snippet', ''), source_info

    if not await check_robots_txt(source_info['link']):
        logging.info(f"Scraping disallowed by robots.txt for {source_info['link']}")
        return source.get('snippet', ''), source_info

    try:
        logging.info(f"Processing source: {source_info['link']}")
        start_time = time.time()

        if any(source_info['link'].lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.ppt', '.pptx', '.xls', '.xlsx']):
            logging.info(f"Skipping non-HTML content at {source_info['link']}")
            return source.get('snippet', ''), source_info

        await asyncio.sleep(REQUEST_DELAY)

        async with session.get(source_info['link'], headers=headers, timeout=timeout, ssl=False) as response:
            if response.status != 200:
                logging.warning(f"HTTP {response.status} for {source_info['link']}")
                return source.get('snippet', ''), source_info

            content_type = response.headers.get('Content-Type', '').lower()
            if 'text/html' not in content_type:
                logging.info(f"Non-HTML content at {source_info['link']} (type: {content_type})")
                return source.get('snippet', ''), source_info

            html = await response.text()
            soup = BeautifulSoup(html, "html.parser")

            for tag in soup(['script', 'style', 'nav', 'footer', 'header', 'aside', 'iframe', 'noscript', 'form']):
                tag.decompose()

            selectors_to_try = [
                'main',
                'article',
                '[role="main"]',
                '.main-content',
                '.content',
                '.article-body',
                '.post-content',
                '.entry-content',
                '#content',
                '#main',
                '.main',
                '.article'
            ]

            main_content = None
            for selector in selectors_to_try:
                main_content = soup.select_one(selector)
                if main_content:
                    break

            if not main_content:
                all_elements = soup.find_all()
                candidates = [el for el in all_elements if el.name not in ['script', 'style', 'nav', 'footer', 'header']]
                if candidates:
                    candidates.sort(key=lambda x: len(x.get_text()), reverse=True)
                    main_content = candidates[0] if candidates else soup

            if not main_content:
                main_content = soup.find('body') or soup

            content = " ".join(main_content.stripped_strings)
            content = re.sub(r'\s+', ' ', content).strip()

            if len(content.split()) < 50 and len(html) > 10000:
                paras = soup.find_all('p')
                content = " ".join([p.get_text() for p in paras if p.get_text().strip()])
                content = re.sub(r'\s+', ' ', content).strip()

                if len(content.split()) < 50:
                    content = " ".join(soup.stripped_strings)
                    content = re.sub(r'\s+', ' ', content).strip()

            if len(content.split()) < 30:
                for tag in ['div', 'section', 'article']:
                    for element in soup.find_all(tag):
                        if len(element.get_text().split()) > 200:
                            content = " ".join(element.stripped_strings)
                            content = re.sub(r'\s+', ' ', content).strip()
                            if len(content.split()) >= 30:
                                break
                    if len(content.split()) >= 30:
                        break

            if len(content.split()) < 30:
                logging.warning(f"Very little content extracted from {source_info['link']}")
                return source.get('snippet', ''), source_info

            source_info['word_count'] = len(content.split())
            source_info['processing_time'] = time.time() - start_time
            return content, source_info

    except asyncio.TimeoutError:
        logging.warning(f"Timeout while processing {source_info['link']}")
        return source.get('snippet', ''), source_info
    except Exception as e:
        logging.warning(f"Error processing {source_info['link']}: {str(e)[:200]}")
        return source.get('snippet', ''), source_info

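# process_web_source() always returns a (content, source_info) tuple: on success the
# extracted page text plus source_info enriched with "word_count" and "processing_time";
# on any failure it falls back to the original search snippet.
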
async def generate_research_plan(query: str, session: aiohttp.ClientSession) -> List[str]:
    """Generate a comprehensive research plan with sub-questions."""
    try:
        plan_prompt = {
            "model": LLM_MODEL,
            "messages": [{
                "role": "user",
                "content": f"""Generate 4-8 comprehensive sub-questions for in-depth research on '{query}'.
Focus on key aspects that would provide a complete understanding of the topic.
Your response MUST be ONLY the raw JSON array with no additional text.
Example: [\"What is the historical background of X?\", \"What are the current trends in X?\"]"""
            }],
            "temperature": 0.7
        }

        async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=plan_prompt, timeout=30) as response:
            response.raise_for_status()
            result = await response.json()

            if isinstance(result, list):
                return result
            elif isinstance(result, dict) and 'choices' in result:
                content = result['choices'][0]['message']['content']
                sub_questions = extract_json_from_llm_response(content)
                if sub_questions and isinstance(sub_questions, list):
                    cleaned = []
                    for q in sub_questions:
                        if isinstance(q, str) and q.strip():
                            cleaned_q = re.sub(r'^[^a-zA-Z0-9]*|[^a-zA-Z0-9]*$', '', q)
                            if cleaned_q:
                                cleaned.append(cleaned_q)
                    return cleaned[:6]

        return [
            f"What is {query} and its key features?",
            f"How does {query} compare to alternatives?",
            f"What are the current developments in {query}?",
            f"What are the main challenges with {query}?",
            f"What does the future hold for {query}?"
        ]
    except Exception as e:
        logging.error(f"Failed to generate research plan: {e}")
        return [
            f"What is {query}?",
            f"What are the key aspects of {query}?",
            f"What are current trends in {query}?",
            f"What are the challenges with {query}?"
        ]

async def continuous_search(query: str, search_time: int = 300) -> AsyncGenerator[Dict[str, Any], None]:
    """Perform continuous searching with retries and diverse queries, yielding updates for each new result."""
    start_time = time.time()
    all_results = []
    seen_urls = set()
    fallback_results = []

    query_variations = [
        query,
        f"{query} comparison",
        f"{query} review",
        f"{query} latest developments",
        f"{query} features and benefits",
        f"{query} challenges and limitations"
    ]

    async with aiohttp.ClientSession() as session:
        iteration = 0
        result_count = 0
        while time.time() - start_time < search_time:
            iteration += 1
            random.shuffle(query_variations)
            for q in query_variations:
                if time.time() - start_time >= search_time:
                    logger.info(f"Search timed out after {search_time} seconds. Found {len(all_results)} results.")
                    break

                logger.info(f"Iteration {iteration}: Searching for query variation: {q}")
                yield {"event": "status", "data": f"Searching for '{q}'..."}

                try:
                    results = await fetch_search_results(q, max_results=5)
                    logger.info(f"Retrieved {len(results)} results for query '{q}'")
                    for result in results:
                        clean_link = clean_url(result['link'])
                        if clean_link and clean_link not in seen_urls:
                            seen_urls.add(clean_link)
                            result['link'] = clean_link
                            all_results.append(result)
                            fallback_results.append(result)
                            result_count += 1
                            logger.info(f"Added new result: {result['title']} ({result['link']})")
                            yield {"event": "found_result", "data": f"Found result {result_count}: {result['title']} ({result['link']})"}

                    await asyncio.sleep(REQUEST_DELAY)
                    if len(all_results) >= MAX_SOURCES_TO_PROCESS * 1.5:
                        logger.info(f"Reached sufficient results: {len(all_results)}")
                        break
                except Exception as e:
                    logger.error(f"Error during search for '{q}': {e}")
                    yield {"event": "warning", "data": f"Search error for '{q}': {str(e)[:100]}"}
                    await asyncio.sleep(RETRY_DELAY)

            if len(all_results) >= MAX_SOURCES_TO_PROCESS * 1.5:
                break

    logger.info(f"Completed continuous search. Total results: {len(all_results)}")

    if len(all_results) < MAX_SOURCES_TO_PROCESS:
        logger.warning(f"Insufficient results ({len(all_results)}), using fallback results")
        yield {"event": "warning", "data": "Insufficient results, using fallback results to reach minimum."}
        all_results.extend(fallback_results[:MAX_SOURCES_TO_PROCESS - len(all_results)])

    if all_results:
        def score_result(result):
            query_terms = set(query.lower().split())
            title = result['title'].lower()
            snippet = result['snippet'].lower()
            matches = sum(1 for term in query_terms if term in title or term in snippet)
            snippet_length = len(result['snippet'].split())
            return matches * 10 + snippet_length

        all_results.sort(key=score_result, reverse=True)

    yield {"event": "final_search_results", "data": all_results[:MAX_SOURCES_TO_PROCESS * 2]}

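# continuous_search() yields dicts of the form {"event": ..., "data": ...}. The events
# emitted are "status", "found_result", "warning", and a final "final_search_results"
# whose data is the ranked result list consumed by the callers below.
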
async def filter_and_select_sources(results: List[dict]) -> List[dict]:
    """Filter and select the best sources from search results."""
    if not results:
        logger.warning("No search results to filter.")
        return []

    logger.info(f"Filtering {len(results)} search results...")

    domain_counts = defaultdict(int)
    domain_results = defaultdict(list)
    for result in results:
        domain = urlparse(result['link']).netloc
        domain_counts[domain] += 1
        domain_results[domain].append(result)

    selected = []
    for domain, domain_res in domain_results.items():
        if len(selected) >= MAX_SOURCES_TO_PROCESS:
            break
        if domain_res:
            selected.append(domain_res[0])
            logger.info(f"Selected top result from domain {domain}: {domain_res[0]['link']}")

    if len(selected) < MAX_SOURCES_TO_PROCESS:
        domain_quality = {}
        for domain, domain_res in domain_results.items():
            avg_length = sum(len(r['snippet'].split()) for r in domain_res) / len(domain_res)
            domain_quality[domain] = avg_length

        sorted_domains = sorted(domain_quality.items(), key=lambda x: x[1], reverse=True)
        for domain, _ in sorted_domains:
            if len(selected) >= MAX_SOURCES_TO_PROCESS:
                break
            for res in domain_results[domain]:
                if res not in selected:
                    selected.append(res)
                    logger.info(f"Added additional result from high-quality domain {domain}: {res['link']}")
                    if len(selected) >= MAX_SOURCES_TO_PROCESS:
                        break

    if len(selected) < MAX_SOURCES_TO_PROCESS:
        all_results_sorted = sorted(results, key=lambda x: len(x['snippet'].split()), reverse=True)
        for res in all_results_sorted:
            if res not in selected:
                selected.append(res)
                logger.info(f"Added fallback high-snippet result: {res['link']}")
                if len(selected) >= MAX_SOURCES_TO_PROCESS:
                    break

    logger.info(f"Selected {len(selected)} sources after filtering.")
    return selected[:MAX_SOURCES_TO_PROCESS]

async def run_deep_research_stream(query: str, search_time: int = 300) -> AsyncGenerator[str, None]:
    """Run the full research pipeline and stream progress and report text as Server-Sent Events."""

    def format_sse(data: dict) -> str:
        """Format a dict as a single SSE frame."""
        return f"data: {json.dumps(data)}\n\n"

    start_time = time.time()
    processed_sources = 0
    successful_sources = 0
    total_tokens = 0

    try:
        yield format_sse({
            "event": "status",
            "data": f"Starting deep research on '{query}'. Search time limit: {search_time} seconds."
        })

        async with aiohttp.ClientSession() as session:
            yield format_sse({"event": "status", "data": "Generating comprehensive research plan..."})
            try:
                sub_questions = await generate_research_plan(query, session)
                yield format_sse({"event": "plan", "data": sub_questions})
            except Exception as e:
                yield format_sse({
                    "event": "error",
                    "data": f"Failed to generate research plan: {str(e)[:200]}"
                })
                sub_questions = [
                    f"What is {query}?",
                    f"What are the key aspects of {query}?",
                    f"What are current trends in {query}?",
                    f"What are the challenges with {query}?"
                ]
                yield format_sse({"event": "plan", "data": sub_questions})

            yield format_sse({
                "event": "status",
                "data": f"Performing continuous search for up to {search_time} seconds..."
            })

            search_results = []
            async for update in continuous_search(query, search_time):
                if update["event"] == "final_search_results":
                    search_results = update["data"]
                else:
                    yield format_sse(update)

            yield format_sse({
                "event": "status",
                "data": f"Found {len(search_results)} potential sources. Selecting the best ones..."
            })
            yield format_sse({
                "event": "found_sources",
                "data": search_results
            })

            if not search_results:
                yield format_sse({
                    "event": "error",
                    "data": "No search results found. Check your query and try again."
                })
                return

            selected_sources = await filter_and_select_sources(search_results)
            yield format_sse({
                "event": "status",
                "data": f"Selected {len(selected_sources)} high-quality sources to process."
            })
            yield format_sse({
                "event": "selected_sources",
                "data": selected_sources
            })

            if not selected_sources:
                yield format_sse({
                    "event": "error",
                    "data": "No valid sources found after filtering."
                })
                return

            semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
            consolidated_context = ""
            all_sources_used = []
            processing_errors = 0

            async def process_with_semaphore(source):
                async with semaphore:
                    return await process_web_source(session, source, timeout=20)

            processing_tasks = []
            for i, source in enumerate(selected_sources):
                elapsed = time.time() - start_time
                if elapsed > TOTAL_TIMEOUT * 0.8:
                    yield format_sse({
                        "event": "status",
                        "data": f"Approaching time limit, stopping source processing at {i}/{len(selected_sources)}"
                    })
                    break

                if i > 0:
                    await asyncio.sleep(REQUEST_DELAY * 0.5)

                task = asyncio.create_task(process_with_semaphore(source))
                processing_tasks.append(task)

                if (i + 1) % 2 == 0 or (i + 1) == len(selected_sources):
                    yield format_sse({
                        "event": "status",
                        "data": f"Queued {min(i + 1, len(selected_sources))}/{len(selected_sources)} sources for processing..."
                    })

            for future in asyncio.as_completed(processing_tasks):
                processed_sources += 1
                content, source_info = await future
                if content and content.strip():
                    consolidated_context += f"Source: {source_info['link']}\nContent: {content}\n\n---\n\n"
                    all_sources_used.append(source_info)
                    successful_sources += 1
                    total_tokens += len(content.split())
                    yield format_sse({
                        "event": "processed_source",
                        "data": source_info
                    })
                else:
                    processing_errors += 1
                    yield format_sse({
                        "event": "warning",
                        "data": f"Failed to extract content from {source_info['link']}"
                    })

            if not consolidated_context.strip():
                yield format_sse({
                    "event": "error",
                    "data": f"Failed to extract content from any sources. {processing_errors} errors occurred."
                })
                return

            sources_catalog = []
            for idx, s in enumerate(all_sources_used, start=1):
                title = s.get('title') or s.get('link')
                sources_catalog.append({
                    "id": idx,
                    "title": title,
                    "url": s.get('link')
                })

            yield format_sse({
                "event": "status",
                "data": f"Synthesizing a long multi-section report from {successful_sources} sources..."
            })

            sections = [
                {"key": "introduction", "title": "1. Introduction and Background", "target_words": 800},
                {"key": "features", "title": "2. Key Features and Capabilities", "target_words": 900},
                {"key": "comparative", "title": "3. Comparative Analysis with Alternatives", "target_words": 900},
                {"key": "trends", "title": "4. Current Developments and Trends", "target_words": 900},
                {"key": "challenges", "title": "5. Challenges and Limitations", "target_words": 900},
                {"key": "future", "title": "6. Future Outlook", "target_words": 900},
                {"key": "conclusion", "title": "7. Conclusion and Recommendations", "target_words": 600},
            ]

            preface = (
                "You are a meticulous research assistant. Write the requested section in clear, structured markdown. "
                "Use subheadings, bullet lists, and short paragraphs. Provide deep analysis, data points, and concrete examples. "
                "When drawing from a listed source, include inline citations like [n] where n is the source number from the catalog. "
                "Avoid repeating the section title at the top if already included. Do not include a references list inside the section."
            )

            catalog_md = "\n".join([f"[{s['id']}] {s['title']} — {s['url']}" for s in sources_catalog])

            for sec in sections:
                if time.time() - start_time > TOTAL_TIMEOUT:
                    yield format_sse({
                        "event": "warning",
                        "data": "Time limit reached before completing all sections."
                    })
                    break

                yield format_sse({"event": "section_start", "data": {"key": sec["key"], "title": sec["title"]}})

                section_prompt = f"""
{preface}

Write the section titled: "{sec['title']}" (aim for ~{sec['target_words']} words, it's okay to exceed if valuable).

Topic: "{query}"

Sub-questions to consider (optional):
{json.dumps(sub_questions, ensure_ascii=False)}

Source Catalog (use inline citations like [1], [2]):
{catalog_md}

Evidence and notes from crawled sources (trimmed):
{consolidated_context[:MAX_CONTEXT_SIZE]}
"""

                payload = {
                    "model": LLM_MODEL,
                    "messages": [
                        {"role": "system", "content": "You are an expert web research analyst and technical writer."},
                        {"role": "user", "content": section_prompt}
                    ],
                    "stream": True,
                    "temperature": 0.6
                }

                try:
                    async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=payload) as response:
                        if response.status != 200:
                            yield format_sse({
                                "event": "warning",
                                "data": f"Section '{sec['title']}' failed to start (HTTP {response.status}). Skipping."
                            })
                            continue

                        buffer = ""
                        async for line in response.content:
                            if time.time() - start_time > TOTAL_TIMEOUT:
                                yield format_sse({
                                    "event": "warning",
                                    "data": "Time limit reached, halting section generation early."
                                })
                                break

                            line_str = line.decode('utf-8', errors='ignore').strip()
                            if line_str.startswith('data:'):
                                line_str = line_str[5:].strip()
                            if not line_str:
                                continue
                            if line_str == "[DONE]":
                                if buffer:
                                    yield format_sse({"event": "chunk", "data": buffer})
                                    yield format_sse({"event": "section_chunk", "data": {"text": buffer, "section": sec["key"]}})
                                break
                            try:
                                chunk = json.loads(line_str)
                                choices = chunk.get("choices")
                                if choices and isinstance(choices, list):
                                    delta = choices[0].get("delta", {})
                                    content = delta.get("content")
                                    if content:
                                        buffer += content
                                        if len(buffer) >= 400:
                                            yield format_sse({"event": "chunk", "data": buffer})
                                            yield format_sse({"event": "section_chunk", "data": {"text": buffer, "section": sec["key"]}})
                                            buffer = ""
                            except json.JSONDecodeError:
                                continue
                            except Exception as e:
                                logging.warning(f"Error processing stream chunk: {e}")
                                continue

                        if buffer:
                            yield format_sse({"event": "chunk", "data": buffer})
                            yield format_sse({"event": "section_chunk", "data": {"text": buffer, "section": sec["key"]}})

                    yield format_sse({"event": "section_end", "data": {"key": sec["key"], "title": sec["title"]}})
                except Exception as e:
                    yield format_sse({
                        "event": "warning",
                        "data": f"Section '{sec['title']}' failed: {str(e)[:160]}"
                    })

            if sources_catalog:
                refs_md_lines = ["\n\n## References"] + [
                    f"[{s['id']}] {s['title']} — {s['url']}" for s in sources_catalog
                ]
                refs_md = "\n".join(refs_md_lines)

                yield format_sse({"event": "chunk", "data": refs_md})
                yield format_sse({"event": "section_chunk", "data": {"text": refs_md, "section": "references"}})

            duration = time.time() - start_time
            stats = {
                "total_time_seconds": round(duration),
                "sources_processed": processed_sources,
                "sources_successful": successful_sources,
                "estimated_tokens": total_tokens,
                "sources_used": len(all_sources_used)
            }
            yield format_sse({
                "event": "status",
                "data": f"Research completed successfully in {duration:.1f} seconds."
            })
            yield format_sse({"event": "stats", "data": stats})
            yield format_sse({"event": "sources", "data": all_sources_used})

    except asyncio.TimeoutError:
        yield format_sse({
            "event": "error",
            "data": f"Research process timed out after {TOTAL_TIMEOUT} seconds."
        })
    except Exception as e:
        logging.error(f"Critical error in research process: {e}", exc_info=True)
        yield format_sse({
            "event": "error",
            "data": f"An unexpected error occurred: {str(e)[:200]}"
        })
    finally:
        duration = time.time() - start_time
        yield format_sse({
            "event": "complete",
            "data": f"Research process finished after {duration:.1f} seconds."
        })

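# Summary of the SSE events emitted by run_deep_research_stream() (as yielded above):
# "status", "plan", "found_result", "warning", "found_sources", "selected_sources",
# "processed_source", "section_start", "chunk", "section_chunk", "section_end",
# "stats", "sources", "error", and a final "complete". Each frame is a single
# "data: <json>\n\n" line produced by format_sse().
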
@app.post("/deep-research", response_class=StreamingResponse)
async def deep_research_endpoint(request: DeepResearchRequest):
    """Endpoint for deep research that streams SSE responses."""
    if not request.query or len(request.query.strip()) < 3:
        raise HTTPException(status_code=400, detail="Query must be at least 3 characters long")

    search_time = min(max(request.search_time, 60), 300)
    return StreamingResponse(
        run_deep_research_stream(request.query.strip(), search_time),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )

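# Example client call (assumes the server is running locally on the port used in the
# __main__ block below; adjust host/port as needed):
#
#   curl -N -X POST http://localhost:8000/deep-research \
#        -H "Content-Type: application/json" \
#        -d '{"query": "quantum computing", "search_time": 120}'
#
# The -N flag disables curl's buffering so the SSE frames appear as they are streamed.
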
@app.post("/v1/search")
async def search_only_endpoint(request: SearchRequest):
    """Search-only endpoint that returns JSON (no streaming)."""
    if not request.query or len(request.query.strip()) < 3:
        raise HTTPException(status_code=400, detail="Query must be at least 3 characters long")

    search_time = min(max(int(request.search_time), 5), 300)
    max_results = min(max(int(request.max_results), 1), MAX_SOURCES_TO_PROCESS * 2)

    aggregated: List[Dict[str, str]] = []
    async for update in continuous_search(request.query.strip(), search_time):
        if update.get("event") == "final_search_results":
            aggregated = update.get("data", [])

    dedup: List[Dict[str, str]] = []
    seen: set = set()
    for r in aggregated:
        link = clean_url(r.get("link", ""))
        title = r.get("title", "")
        snippet = r.get("snippet", "")
        if not link:
            continue
        if link in seen:
            continue
        seen.add(link)
        dedup.append({"title": title, "link": link, "snippet": snippet})
        if len(dedup) >= max_results:
            break

    return {
        "query": request.query.strip(),
        "count": len(dedup),
        "results": dedup,
    }

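# The /v1/search endpoint returns plain JSON rather than a stream, e.g. (illustrative values):
#
#   {"query": "quantum computing", "count": 2,
#    "results": [{"title": "...", "link": "https://...", "snippet": "..."}, ...]}
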
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)