rkihacker committed
Commit 5b2a6b6 · verified · 1 Parent(s): 0e14740

Update main.py

Files changed (1):
  1. main.py +42 -29
main.py CHANGED
@@ -4,7 +4,7 @@ import json
 import logging
 from typing import AsyncGenerator
 
-from fastapi import FastAPI, HTTPException, Query
+from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from dotenv import load_dotenv
@@ -18,36 +18,39 @@ logger = logging.getLogger(__name__)
 load_dotenv()
 LLM_API_KEY = os.getenv("LLM_API_KEY")
 
-# ***** CHANGE 1: Add API Key loading confirmation *****
 if not LLM_API_KEY:
     raise RuntimeError("LLM_API_KEY must be set in a .env file.")
 else:
     logger.info(f"LLM API Key loaded successfully (starts with: {LLM_API_KEY[:4]}...).")
 
-# API URLs, Models, and a new constant for context size
+# API URLs, Models, and context size limit
 SNAPZION_API_URL = "https://search.snapzion.com/get-snippets"
 LLM_API_URL = "https://api.inference.net/v1/chat/completions"
-LLM_MODEL = "mistralai/mistral-nemo-12b-instruct/fp-8"
-MAX_CONTEXT_CHAR_LENGTH = 120000 # Safeguard: roughly 30k tokens
+LLM_MODEL = "mistralai/mistral-nemo-12b-instruct/fp-8" # Corrected model name from previous attempts
+MAX_CONTEXT_CHAR_LENGTH = 120000
 
 # Headers for external services
 SNAPZION_HEADERS = { 'Content-Type': 'application/json', 'User-Agent': 'AI-Deep-Research-Agent/1.0' }
 SCRAPING_HEADERS = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36' }
-LLM_HEADERS = { "Authorization": f"Bearer {LLM_API_KEY}", "Content-Type": "application/json" }
-
-# --- Pydantic Models for Request Body ---
+# ***** CHANGE 1: Add a User-Agent to the LLM headers *****
+LLM_HEADERS = {
+    "Authorization": f"Bearer {LLM_API_KEY}",
+    "Content-Type": "application/json",
+    "User-Agent": "AI-Deep-Research-Client/2.2"
+}
+
+# --- Pydantic Models ---
 class DeepResearchRequest(BaseModel):
     query: str
 
-# --- FastAPI App Initialization ---
+# --- FastAPI App ---
 app = FastAPI(
     title="AI Deep Research API",
-    description="Provides single-shot AI search and streaming deep research completions.",
-    version="2.1.0" # Version bump for new robustness feature
+    description="Provides streaming deep research completions.",
+    version="2.2.0" # Version bump for critical bug fix
 )
 
 # --- Core Service Functions (Unchanged) ---
-
 async def call_snapzion_search(session: aiohttp.ClientSession, query: str) -> list:
     try:
         async with session.post(SNAPZION_API_URL, headers=SNAPZION_HEADERS, json={"query": query}, timeout=15) as response:
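For a quick local check that the key and the new `User-Agent` header are accepted, a minimal standalone script along these lines can be run before starting the server. This is a sketch, not part of the commit: the `smoke_test` name and one-word prompt are placeholders, while the URL, model, and headers mirror the constants above.

```python
import asyncio
import os

import aiohttp
from dotenv import load_dotenv

load_dotenv()

# Mirrors the constants defined in main.py.
LLM_API_URL = "https://api.inference.net/v1/chat/completions"
LLM_MODEL = "mistralai/mistral-nemo-12b-instruct/fp-8"
LLM_HEADERS = {
    "Authorization": f"Bearer {os.getenv('LLM_API_KEY')}",
    "Content-Type": "application/json",
    "User-Agent": "AI-Deep-Research-Client/2.2",
}

async def smoke_test() -> None:
    # A one-word prompt is enough to confirm auth and headers are accepted.
    payload = {"model": LLM_MODEL, "messages": [{"role": "user", "content": "ping"}]}
    async with aiohttp.ClientSession() as session:
        async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=payload, timeout=20) as resp:
            print(resp.status, (await resp.text())[:200])

if __name__ == "__main__":
    asyncio.run(smoke_test())
```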
@@ -76,10 +79,8 @@ async def search_and_scrape(session: aiohttp.ClientSession, query: str) -> tuple
     search_results = await call_snapzion_search(session, query)
     sources = search_results[:4]
     if not sources: return "", []
-
     scrape_tasks = [scrape_url(session, source["link"]) for source in sources]
     scraped_contents = await asyncio.gather(*scrape_tasks)
-
     context = "\n\n".join(
         f"Source Details: Title '{sources[i]['title']}', URL '{sources[i]['link']}'\nContent:\n{content}"
         for i, content in enumerate(scraped_contents) if not content.startswith("Error:")
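The pairing of `scraped_contents[i]` with `sources[i]` works because `asyncio.gather` returns results in the order the awaitables were passed in, regardless of completion order. A self-contained illustration with stand-in coroutines (not the real `scrape_url`):

```python
import asyncio

async def fake_scrape(url: str, delay: float) -> str:
    # Stand-in for scrape_url: each "scrape" finishes at a different time.
    await asyncio.sleep(delay)
    return f"content from {url}"

async def main() -> None:
    urls = ["https://a.example", "https://b.example", "https://c.example"]
    delays = [0.3, 0.2, 0.1]  # later URLs finish first
    # gather preserves submission order, so results[i] matches urls[i]
    # even though the coroutines complete in reverse order.
    results = await asyncio.gather(*(fake_scrape(u, d) for u, d in zip(urls, delays)))
    for url, content in zip(urls, results):
        print(url, "->", content)

asyncio.run(main())
```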
@@ -95,20 +96,38 @@ async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
 
     try:
         async with aiohttp.ClientSession() as session:
-            # Step 1: Generate Sub-Questions (Unchanged)
+            # Step 1: Generate Sub-Questions
            yield format_sse({"event": "status", "data": "Generating research plan..."})
+
             sub_question_prompt = {
                 "model": LLM_MODEL,
-                "messages": [{ "role": "user", "content": f"You are a research planner. For the topic '{query}', create a JSON array of 3-4 key sub-questions for a research report. Example: [\"Question 1?\", \"Question 2?\"]" }]
+                "messages": [{ "role": "user", "content": f"You are a research planner. For the topic '{query}', create a JSON array of 3-4 key sub-questions for a research report. Respond ONLY with the JSON array. Example: [\"Question 1?\", \"Question 2?\"]" }]
             }
-            async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=sub_question_prompt) as response:
-                response.raise_for_status()
-                result = await response.json()
-                sub_questions = json.loads(result['choices'][0]['message']['content'])
-
+
+            # ***** CHANGE 2: Implement robust parsing for the API call *****
+            try:
+                async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=sub_question_prompt, timeout=20) as response:
+                    if response.status != 200:
+                        error_text = await response.text()
+                        logger.error(f"LLM API for planning failed with status {response.status}: {error_text}")
+                        raise Exception(f"LLM API returned non-200 status: {response.status}")
+
+                    raw_response_text = await response.text()
+                    if not raw_response_text:
+                        raise Exception("LLM API returned an empty response.")
+
+                    result = json.loads(raw_response_text)
+                    llm_content = result['choices'][0]['message']['content']
+                    sub_questions = json.loads(llm_content)
+            except Exception as e:
+                logger.error(f"Failed to generate or parse research plan: {e}")
+                yield format_sse({"event": "error", "data": f"Could not generate research plan. Reason: {e}"})
+                return # Stop the process if planning fails
+
             yield format_sse({"event": "plan", "data": sub_questions})
 
-            # Step 2: Concurrently research all sub-questions (Unchanged)
+            # (The rest of the logic remains the same)
+            # Step 2: Concurrently research all sub-questions
             research_tasks = [search_and_scrape(session, sq) for sq in sub_questions]
             all_research_results = []
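The fix above catches transport errors and empty bodies, but the final `json.loads(llm_content)` still assumes the model answered with a bare JSON array. Models often wrap output in Markdown fences or surrounding prose, so a defensive parser could be swapped in later; the helper below is a hypothetical extension, not part of this commit:

```python
import json
import re

def parse_sub_questions(llm_content: str) -> list[str]:
    """Hypothetical helper: tolerate ```json fences and surrounding
    prose around the JSON array the planner prompt asks for."""
    text = llm_content.strip()
    # Drop a Markdown code fence if the model added one.
    fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    # Fall back to the first [...] span if prose surrounds the array.
    if not text.startswith("["):
        start, end = text.find("["), text.rfind("]")
        if start == -1 or end == -1:
            raise ValueError("No JSON array found in LLM response.")
        text = text[start:end + 1]
    questions = json.loads(text)
    if not isinstance(questions, list):
        raise ValueError("Planner output is not a JSON array.")
    return questions
```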
 
@@ -123,8 +142,6 @@ async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
             all_sources = [source for res in all_research_results for source in res[1]]
             unique_sources = list({s['link']: s for s in all_sources}.values())
 
-            # ***** CHANGE 2: Implement the context truncation safeguard *****
-            logger.info(f"Consolidated context size: {len(full_context)} characters.")
             if len(full_context) > MAX_CONTEXT_CHAR_LENGTH:
                 logger.warning(f"Context is too long. Truncating from {len(full_context)} to {MAX_CONTEXT_CHAR_LENGTH} characters.")
                 full_context = full_context[:MAX_CONTEXT_CHAR_LENGTH]
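The 120,000-character cap is a rough proxy for the model's context window (about 30k tokens at the common ~4 characters per token estimate). A hard slice can cut a source entry mid-sentence; a variant that backs up to the last `\n\n` source separator would be slightly cleaner. This is a hypothetical refinement, not what the commit does:

```python
def truncate_context(full_context: str, limit: int = 120_000) -> str:
    """Hypothetical refinement: cut at the last source separator
    ("\n\n") before the limit so no source entry is cut mid-sentence."""
    if len(full_context) <= limit:
        return full_context
    cut = full_context.rfind("\n\n", 0, limit)
    return full_context[:cut if cut != -1 else limit]
```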
@@ -140,14 +157,11 @@ async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
             final_report_payload = {"model": LLM_MODEL, "messages": [{"role": "user", "content": final_report_prompt}], "stream": True}
 
             async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=final_report_payload) as response:
-                # ***** CHANGE 3: More robust error handling for the streaming call *****
                 if response.status != 200:
                     error_text = await response.text()
-                    logger.error(f"LLM API returned a non-200 status: {response.status} - {error_text}")
-                    raise Exception(f"LLM API Error: {response.status}, {error_text}")
+                    raise Exception(f"LLM API Error for final report: {response.status}, {error_text}")
 
                 async for line in response.content:
-                    # (Rest of the streaming logic is the same)
                     if line.strip():
                         line_str = line.decode('utf-8').strip()
                         if line_str.startswith('data:'): line_str = line_str[5:].strip()
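The hunk cuts off right after the `data:` prefix is stripped. For an OpenAI-compatible stream, the remainder of that loop typically decodes each line as a chat-completion chunk and pulls the text delta out of `choices[0].delta.content`; the function below sketches that per-line step under those assumptions (the actual continuation in main.py is not shown in this diff):

```python
import json
from typing import Optional

def parse_stream_line(raw: bytes) -> Optional[str]:
    """Sketch of per-line parsing for an OpenAI-compatible stream.
    Returns the text delta, or None when the line carries no text."""
    line_str = raw.decode("utf-8").strip()
    if not line_str:
        return None
    if line_str.startswith("data:"):
        line_str = line_str[5:].strip()
    if line_str == "[DONE]":  # end-of-stream sentinel
        return None
    chunk = json.loads(line_str)
    # Standard chunk shape: choices[0].delta.content holds the new text.
    return chunk["choices"][0].get("delta", {}).get("content")
```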
@@ -166,7 +180,6 @@ async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
     finally:
         yield format_sse({"event": "done", "data": "Deep research complete."})
 
-
 # --- API Endpoints ---
 @app.post("/v1/deepresearch/completions")
 async def deep_research_endpoint(request: DeepResearchRequest):
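The endpoint body is cut off here; given the async generator above, it presumably wraps `run_deep_research_stream` in a `StreamingResponse` with an SSE media type. A minimal self-contained sketch of that wiring, with a stand-in generator and an assumed `format_sse` definition (neither is shown in this diff):

```python
import asyncio
import json
from typing import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class DeepResearchRequest(BaseModel):
    query: str

def format_sse(payload: dict) -> str:
    # Assumed framing: one "data:" line per JSON event, blank-line terminated.
    return f"data: {json.dumps(payload)}\n\n"

async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
    # Stand-in for the real generator in main.py.
    yield format_sse({"event": "status", "data": f"Researching: {query}"})
    await asyncio.sleep(0)
    yield format_sse({"event": "done", "data": "Deep research complete."})

@app.post("/v1/deepresearch/completions")
async def deep_research_endpoint(request: DeepResearchRequest):
    return StreamingResponse(run_deep_research_stream(request.query),
                             media_type="text/event-stream")
```

With the server running under uvicorn, something like `curl -N -X POST http://localhost:8000/v1/deepresearch/completions -H 'Content-Type: application/json' -d '{"query": "example"}'` would then print events as they arrive (the host and port are assumptions).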
 