File size: 9,026 Bytes
4b17916 2a0098d 6142af3 2a0098d 6142af3 2a0098d 4b17916 2a0098d 4b17916 e1111e0 2a0098d 6142af3 2a0098d 4b17916 6142af3 2a0098d 4b17916 6142af3 4b17916 6142af3 2a0098d 4b17916 2a0098d 4b17916 2a0098d 6142af3 2a0098d 6142af3 4b17916 e1111e0 6142af3 2a0098d 6142af3 2a0098d 6142af3 2a0098d 6142af3 4b17916 6142af3 2a0098d 6142af3 2a0098d 6142af3 2a0098d 6142af3 2a0098d 6142af3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import os
import asyncio
import json
import logging
from typing import AsyncGenerator
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dotenv import load_dotenv
import aiohttp
from bs4 import BeautifulSoup
# --- Configuration ---
# Process-wide logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from a .env file into the environment before reading the key.
load_dotenv()
LLM_API_KEY = os.environ.get("LLM_API_KEY")
if LLM_API_KEY is None or LLM_API_KEY == "":
    # Fail at import time: every LLM request needs this credential.
    raise RuntimeError("LLM_API_KEY must be set in a .env file.")
# Endpoints and model identifier for the two external services this module calls.
SNAPZION_API_URL = "https://search.snapzion.com/get-snippets"
LLM_API_URL = "https://api.inference.net/v1/chat/completions"
LLM_MODEL = "meta-llama/llama-3.1-8b-instruct/fp-8"

# Request headers, one set per outbound destination.
SNAPZION_HEADERS = {
    "Content-Type": "application/json",
    "User-Agent": "AI-Deep-Research-Agent/1.0",
}
# Sent when fetching arbitrary result pages.
SCRAPING_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/140.0.0.0 Safari/537.36"
    ),
}
# Bearer authentication for the LLM API, built from the key loaded above.
LLM_HEADERS = {
    "Authorization": f"Bearer {LLM_API_KEY}",
    "Content-Type": "application/json",
}
# --- Pydantic Models for Request Body ---
class DeepResearchRequest(BaseModel):
    """Request body accepted by the deep research endpoint."""

    # The topic or question to research; passed to the research stream as-is.
    query: str
# --- FastAPI App Initialization ---
# The application object that the route decorators below register against.
app = FastAPI(
    version="2.0.0",
    title="AI Deep Research API",
    description="Provides single-shot AI search and streaming deep research completions.",
)
# --- Core Service Functions (Reused and New) ---
async def call_snapzion_search(session: aiohttp.ClientSession, query: str) -> list:
    """Query the Snapzion search API and return its organic results.

    Args:
        session: Open aiohttp session used to send the request.
        query: Search text, posted as the ``query`` field of a JSON body.

    Returns:
        The response's ``organic_results`` value when it is a list, otherwise
        an empty list. Ordinary failures (network error, timeout, non-2xx
        status, undecodable or unexpectedly shaped JSON) are logged and also
        produce an empty list rather than an exception.
    """
    # A bare number is accepted by aiohttp only for backward compatibility;
    # ClientTimeout is the documented way to express a total time budget.
    timeout = aiohttp.ClientTimeout(total=15)
    try:
        async with session.post(
            SNAPZION_API_URL,
            headers=SNAPZION_HEADERS,
            json={"query": query},
            timeout=timeout,
        ) as response:
            response.raise_for_status()
            data = await response.json()
            results = data.get("organic_results")
            # A missing key, an explicit null or a non-list value would break
            # callers that slice the result, so normalise all of them to [].
            return results if isinstance(results, list) else []
    except Exception as e:
        # Deliberately broad: search is best-effort and must not crash the run.
        logger.error(f"Snapzion search failed for query '{query}': {e}")
        return []  # Return empty list on failure instead of crashing
async def scrape_url(session: aiohttp.ClientSession, url: str) -> str:
    """Fetch a web page and return its visible text.

    Args:
        session: Open aiohttp session used to send the request.
        url: Address of the page to fetch.

    Returns:
        The page's text fragments joined by single spaces, with ``script``,
        ``style``, ``nav``, ``footer``, ``header`` and ``aside`` elements
        removed. On any failure the return value is a string starting with
        ``"Error:"`` instead; callers detect failure by that prefix, so this
        function never raises for ordinary errors.
    """
    pdf_error = "Error: PDF content cannot be scraped."

    # Check the URL with its query string and fragment removed as well, so
    # that "report.pdf?download=1" is recognised as a PDF too.
    lowered = url.lower()
    path_only = lowered.partition("#")[0].partition("?")[0]
    if lowered.endswith(".pdf") or path_only.endswith(".pdf"):
        return pdf_error

    try:
        async with session.get(
            url,
            headers=SCRAPING_HEADERS,
            # ClientTimeout is the documented form; a bare number is legacy.
            timeout=aiohttp.ClientTimeout(total=10),
            # NOTE(security): certificate verification is disabled, so the
            # fetched content is not authenticated. Kept as-is because the
            # existing behaviour tolerates sites with broken certificates.
            ssl=False,
        ) as response:
            if response.status != 200:
                return f"Error: HTTP status {response.status}"
            # PDFs are also served from URLs that have no .pdf extension.
            if response.content_type == "application/pdf":
                return pdf_error
            html = await response.text()
            soup = BeautifulSoup(html, "html.parser")
            # Drop non-content and page-chrome elements before extracting text.
            for tag in soup(["script", "style", "nav", "footer", "header", "aside"]):
                tag.decompose()
            return " ".join(soup.stripped_strings)
    except Exception as e:
        # Deliberately broad: one unreachable page must not fail the batch.
        logger.warning(f"Scraping failed for {url}: {e}")
        return f"Error: {e}"
async def search_and_scrape(
    session: aiohttp.ClientSession, query: str, max_sources: int = 4
) -> tuple[str, list]:
    """Performs the search and scrape pipeline for a given query.

    Args:
        session: Open aiohttp session shared by the search and scrape calls.
        query: Search text to look up.
        max_sources: Maximum number of search results to scrape (default 4,
            the previously hard-coded value).

    Returns:
        A ``(context, sources)`` tuple. ``context`` joins the text of every
        successfully scraped page, each prefixed with its 1-based position in
        ``sources`` and its link. ``sources`` is the list of search results
        that were selected for scraping, including those whose scrape failed.
        Returns ``("", [])`` when no usable search result is available.
    """
    search_results = await call_snapzion_search(session, query)

    # A result without a string "link" can be neither scraped nor cited.
    # Indexing it directly would raise and abort the whole pipeline.
    linkable = [
        result
        for result in search_results
        if isinstance(result, dict)
        and isinstance(result.get("link"), str)
        and result["link"]
    ]
    sources = linkable[:max(max_sources, 0)]
    if not sources:
        return "", []

    # Fetch all selected pages concurrently; gather preserves input order,
    # so each content string lines up with its source.
    scraped_contents = await asyncio.gather(
        *(scrape_url(session, source["link"]) for source in sources)
    )

    # scrape_url reports failure as a string starting with "Error:"; skip
    # those but keep numbering aligned with the position in `sources`.
    context = "\n\n".join(
        f"Source [{position}] (from {source['link']}):\n{content}"
        for position, (source, content) in enumerate(
            zip(sources, scraped_contents), start=1
        )
        if not content.startswith("Error:")
    )
    return context, sources
# --- Streaming Deep Research Logic ---
def _format_sse(data: dict) -> str:
    """Formats a dictionary as a Server-Sent Event string."""
    return f"data: {json.dumps(data)}\n\n"


def _parse_sub_questions(llm_result: dict) -> list[str]:
    """Extract the planner's sub-questions from a chat-completions response.

    Args:
        llm_result: Decoded JSON body returned by the planning request.

    Returns:
        The non-empty question strings, in the order the model listed them.

    Raises:
        ValueError: If the response carries no text content, the content does
            not contain a JSON array, or the array holds no usable strings.
            (``json.JSONDecodeError`` is a subclass of ``ValueError``.)
    """
    try:
        content = llm_result["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError("LLM response has no message content") from exc
    if not isinstance(content, str):
        raise ValueError("LLM message content is not text")

    text = content.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Models often wrap the array in a Markdown code fence or surround it
        # with a sentence; retry on the outermost [...] span before giving up.
        start, end = text.find("["), text.rfind("]")
        if start == -1 or end < start:
            raise
        parsed = json.loads(text[start:end + 1])

    if not isinstance(parsed, list):
        raise ValueError("LLM did not return a JSON array")
    questions = [q.strip() for q in parsed if isinstance(q, str) and q.strip()]
    if not questions:
        raise ValueError("LLM returned no usable sub-questions")
    return questions


def _extract_delta_content(payload: str) -> str:
    """Return the text carried in ``choices[0].delta.content`` of a stream chunk.

    Returns an empty string when the payload is not valid JSON or does not
    have that shape (for example a chunk whose ``choices`` list is empty), so
    a single odd chunk cannot abort the stream.
    """
    try:
        chunk = json.loads(payload)
    except json.JSONDecodeError:
        return ""  # Ignore empty or malformed lines
    if not isinstance(chunk, dict):
        return ""
    choices = chunk.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    delta = choices[0].get("delta")
    if not isinstance(delta, dict):
        return ""
    content = delta.get("content")
    return content if isinstance(content, str) else ""


async def _deep_research_events(query: str) -> AsyncGenerator[str, None]:
    """Yield the SSE events of one research run, excluding the final ``done``.

    Exceptions are allowed to propagate; the public wrapper turns them into
    an ``error`` event.
    """
    async with aiohttp.ClientSession() as session:
        # Step 1: Generate Sub-Questions
        yield _format_sse({"event": "status", "data": "Generating research plan..."})
        planner_prompt = (
            f"You are a research planner. Based on the user's query '{query}', "
            "generate a list of 3 to 4 crucial sub-questions that would form the "
            "basis of a comprehensive research report. Respond with ONLY a JSON "
            "array of strings. Example: [\"Question 1?\", \"Question 2?\"]"
        )
        sub_question_payload = {
            "model": LLM_MODEL,
            "messages": [{"role": "user", "content": planner_prompt}],
        }
        async with session.post(
            LLM_API_URL, headers=LLM_HEADERS, json=sub_question_payload
        ) as response:
            response.raise_for_status()
            result = await response.json()
        try:
            sub_questions = _parse_sub_questions(result)
        except ValueError:
            yield _format_sse({"event": "error", "data": "Failed to parse sub-questions from LLM."})
            return
        yield _format_sse({"event": "plan", "data": sub_questions})

        # Step 2: Concurrently research all sub-questions
        for question in sub_questions:
            yield _format_sse({"event": "status", "data": f"Researching: \"{question}\""})
        # gather keeps results in plan order, so the consolidated context is
        # deterministic. It also cancels the pending searches if this
        # generator is cancelled while waiting.
        outcomes = await asyncio.gather(
            *(search_and_scrape(session, question) for question in sub_questions),
            return_exceptions=True,
        )
        all_research_results = []
        for question, outcome in zip(sub_questions, outcomes):
            if isinstance(outcome, BaseException):
                # One failed sub-question should not discard the others.
                logger.warning("Research failed for sub-question %r: %s", question, outcome)
                continue
            all_research_results.append(outcome)

        # Step 3: Consolidate all context and sources
        yield _format_sse({"event": "status", "data": "Consolidating research..."})
        full_context = "\n\n---\n\n".join(
            context for context, _ in all_research_results if context
        )
        # Deduplicate sources by link; entries without a string link are skipped.
        unique_sources = list({
            source["link"]: source
            for _, sources in all_research_results
            for source in sources
            if isinstance(source, dict) and isinstance(source.get("link"), str)
        }.values())
        if not full_context.strip():
            yield _format_sse({"event": "error", "data": "Failed to gather any research context."})
            return

        # Step 4: Generate the final report with streaming
        yield _format_sse({"event": "status", "data": "Generating final report..."})
        final_report_prompt = (
            "\nYou are a research analyst. Your task is to synthesize the provided "
            "context into a comprehensive, well-structured report on the topic: "
            f"\"{query}\".\n"
            "Use the context below exclusively. Do not use outside knowledge. "
            "Structure the report with markdown headings.\n"
            "## Research Context ##\n"
            f"{full_context}\n"
        )
        final_report_payload = {
            "model": LLM_MODEL,
            "messages": [{"role": "user", "content": final_report_prompt}],
            "stream": True,  # Enable streaming from the LLM
        }
        async with session.post(
            LLM_API_URL, headers=LLM_HEADERS, json=final_report_payload
        ) as response:
            response.raise_for_status()
            async for line in response.content:
                # The inference API might wrap its stream chunks in a 'data: ' prefix
                payload = line.decode("utf-8", errors="replace").strip()
                if payload.startswith("data:"):
                    payload = payload[5:].strip()
                if not payload:
                    continue
                if payload == "[DONE]":
                    break
                content = _extract_delta_content(payload)
                if content:
                    yield _format_sse({"event": "chunk", "data": content})
        yield _format_sse({"event": "sources", "data": unique_sources})


async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
    """The main async generator for the deep research process.

    Plans sub-questions with the LLM, researches them concurrently, then
    streams an LLM-written report built from the gathered context.

    Args:
        query: The user's research topic.

    Yields:
        Server-Sent Event strings (``data: <json>`` followed by a blank line)
        whose JSON has an ``event`` of ``status``, ``plan``, ``chunk``,
        ``sources`` or ``error``, and finally one ``done`` event. Unexpected
        exceptions are reported as an ``error`` event rather than raised.
    """
    events = _deep_research_events(query)
    try:
        async for event in events:
            yield event
    except Exception as e:
        logger.exception("An error occurred during deep research: %s", e)
        # Some exceptions (e.g. a bare TimeoutError) have an empty message.
        yield _format_sse({"event": "error", "data": str(e) or type(e).__name__})
    finally:
        # Awaiting is permitted while a generator is being closed or
        # cancelled; this releases the HTTP session promptly.
        await events.aclose()
    # Emitted after the try block, not inside `finally`: yielding while the
    # generator is being closed raises RuntimeError, and yielding during
    # cancellation would swallow the cancellation.
    yield _format_sse({"event": "done", "data": "Deep research complete."})
# --- API Endpoints ---
@app.get("/", include_in_schema=False)
def root():
    """Return a short message saying the API is active and pointing to /docs."""
    status_message = "AI Deep Research API is active. See /docs for details."
    return {"message": status_message}
@app.post("/v1/deepresearch/completions")
async def deep_research_endpoint(request: DeepResearchRequest):
    """
    Performs a multi-step, streaming deep research task.

    A blank `query` is rejected with HTTP 400 before any work starts.

    **Events Streamed:**
    - `status`: Provides updates on the current stage of the process.
    - `plan`: The list of sub-questions that will be researched.
    - `chunk`: A piece of the final generated report.
    - `sources`: The list of web sources used for the report.
    - `error`: Indicates a fatal error occurred.
    - `done`: Signals the end of the stream.
    """
    # Reject a blank query before opening a stream or calling the LLM.
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="Query must not be empty.")
    return StreamingResponse(
        run_deep_research_stream(request.query),
        media_type="text/event-stream",
        # Event streams must not be cached, and "X-Accel-Buffering: no" asks
        # nginx-style reverse proxies to forward events as they are produced.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )