Spaces:
Paused
Paused
| from typing import Optional, List, Dict, Any, Set | |
| import json | |
| from .config import Config | |
| from .memory import Memory | |
| from .utils.enum import ReportSource, ReportType, Tone | |
| from .llm_provider import GenericLLMProvider | |
| from .vector_store import VectorStoreWrapper | |
| # Research skills | |
| from .skills.researcher import ResearchConductor | |
| from .skills.writer import ReportGenerator | |
| from .skills.context_manager import ContextManager | |
| from .skills.browser import BrowserManager | |
| from .skills.curator import SourceCurator | |
| from .actions import ( | |
| add_references, | |
| extract_headers, | |
| extract_sections, | |
| table_of_contents, | |
| get_retrievers, | |
| choose_agent | |
| ) | |
class GPTResearcher:
    """Coordinate research and report writing for a single query.

    The instance holds the configuration and shared state (query, context,
    visited URLs, collected sources/images, accumulated cost) and delegates
    the actual work to the skill components created in ``__init__``:
    ``ResearchConductor``, ``ReportGenerator``, ``ContextManager``,
    ``BrowserManager`` and ``SourceCurator``. Each component receives this
    instance as its only constructor argument.
    """

    def __init__(
        self,
        query: str,
        report_type: str = ReportType.ResearchReport.value,
        report_format: str = "markdown",
        report_source: str = ReportSource.Web.value,
        tone: Tone = Tone.Objective,
        source_urls=None,
        document_urls=None,
        complement_source_urls=False,
        documents=None,
        vector_store=None,
        vector_store_filter=None,
        config_path=None,
        websocket=None,
        agent=None,
        role=None,
        parent_query: str = "",
        subtopics: Optional[list] = None,
        visited_urls: Optional[set] = None,
        verbose: bool = True,
        context=None,
        headers: Optional[dict] = None,
        max_subtopics: int = 5,
        log_handler=None,
    ):
        """Store the run parameters and build the skill components.

        Args:
            query: The research question.
            report_type: Stored as-is on ``self.report_type``.
            report_format: Stored as-is on ``self.report_format``.
            report_source: When falsy, ``cfg.report_source`` is used instead
                (``None`` if the config has no such attribute).
            tone: Any value that is not a ``Tone`` instance is replaced by
                ``Tone.Objective``.
            source_urls: Stored as-is on ``self.source_urls``.
            document_urls: Stored as-is on ``self.document_urls``.
            complement_source_urls: Stored as-is on
                ``self.complement_source_urls``.
            documents: Stored as-is on ``self.documents``.
            vector_store: Wrapped in ``VectorStoreWrapper`` when truthy,
                otherwise ``self.vector_store`` is ``None``.
            vector_store_filter: Stored as-is on ``self.vector_store_filter``.
            config_path: Passed to ``Config``.
            websocket: Stored as-is on ``self.websocket``.
            agent: Optional pre-selected agent; when either ``agent`` or
                ``role`` is falsy, ``conduct_research`` picks both via
                ``choose_agent``.
            role: Optional pre-selected role (see ``agent``).
            parent_query: Forwarded to ``choose_agent``.
            subtopics: Stored by reference when provided; a fresh list is
                created per instance when omitted.
            visited_urls: Stored by reference (not copied) when provided,
                including an empty set; a fresh set is created per instance
                when omitted.
            verbose: Stored as-is on ``self.verbose``.
            context: Stored by reference when provided; a fresh list is
                created per instance when omitted.
            headers: Defaults to ``{}``; passed to ``get_retrievers`` and
                ``choose_agent``.
            max_subtopics: Stored as-is on ``self.max_subtopics``.
            log_handler: Optional object whose ``on_tool_start``,
                ``on_agent_action`` and ``on_research_step`` coroutines are
                awaited by ``_log_event``.
        """
        self.query = query
        self.report_type = report_type
        self.cfg = Config(config_path)
        self.llm = GenericLLMProvider(self.cfg)
        self.report_source = (
            report_source if report_source else getattr(self.cfg, 'report_source', None)
        )
        self.report_format = report_format
        self.max_subtopics = max_subtopics
        self.tone = tone if isinstance(tone, Tone) else Tone.Objective
        self.source_urls = source_urls
        self.document_urls = document_urls
        self.complement_source_urls: bool = complement_source_urls
        self.research_sources = []  # The list of scraped sources including title, content and images
        self.research_images = []  # The list of selected research images
        self.documents = documents
        self.vector_store = VectorStoreWrapper(vector_store) if vector_store else None
        self.vector_store_filter = vector_store_filter
        self.websocket = websocket
        self.agent = agent
        self.role = role
        self.parent_query = parent_query
        # ``is None`` (not truthiness) so that an explicitly passed empty
        # container is still stored by reference, while omitted arguments get
        # a fresh container instead of one shared across all instances.
        self.subtopics = [] if subtopics is None else subtopics
        self.visited_urls = set() if visited_urls is None else visited_urls
        self.verbose = verbose
        self.context = [] if context is None else context
        self.headers = headers or {}
        self.research_costs = 0.0
        self.retrievers = get_retrievers(self.headers, self.cfg)
        self.memory = Memory(
            self.cfg.embedding_provider,
            self.cfg.embedding_model,
            **self.cfg.embedding_kwargs,
        )
        self.log_handler = log_handler
        # Strong references to log tasks scheduled by ``add_costs``; the event
        # loop itself only keeps weak references to tasks.
        self._pending_log_tasks = set()

        # Initialize components
        self.research_conductor: ResearchConductor = ResearchConductor(self)
        self.report_generator: ReportGenerator = ReportGenerator(self)
        self.context_manager: ContextManager = ContextManager(self)
        self.scraper_manager: BrowserManager = BrowserManager(self)
        self.source_curator: SourceCurator = SourceCurator(self)

    async def _log_event(self, event_type: str, **kwargs):
        """Forward an event to ``self.log_handler`` and the 'research' logger.

        Does nothing when no log handler is set. Any exception raised while
        dispatching or logging is caught and reported on the 'research'
        logger, so this method never propagates handler errors.

        Args:
            event_type: ``"tool"``, ``"action"`` or ``"research"``; other
                values skip the handler call but are still written to the
                'research' logger.
            **kwargs: Event payload. ``tool_name``, ``action``, ``step`` and
                ``details`` are read depending on ``event_type``.
        """
        if not self.log_handler:
            return

        import logging

        try:
            # NOTE(review): for "tool" and "action" the key is passed both
            # positionally and inside **kwargs. If the handler names its first
            # parameter ``tool_name``/``action`` this raises TypeError (caught
            # below) - confirm against the handler's signature.
            if event_type == "tool":
                await self.log_handler.on_tool_start(kwargs.get('tool_name', ''), **kwargs)
            elif event_type == "action":
                await self.log_handler.on_agent_action(kwargs.get('action', ''), **kwargs)
            elif event_type == "research":
                await self.log_handler.on_research_step(
                    kwargs.get('step', ''), kwargs.get('details', {})
                )

            # Add direct logging as backup
            research_logger = logging.getLogger('research')
            research_logger.info(f"{event_type}: {json.dumps(kwargs, default=str)}")
        except Exception as e:
            logging.getLogger('research').error(f"Error in _log_event: {e}", exc_info=True)

    async def conduct_research(self):
        """Select an agent if needed, run the research and return the context.

        When ``self.agent`` or ``self.role`` is falsy, both are set from
        ``choose_agent``. The result of
        ``self.research_conductor.conduct_research()`` is stored on
        ``self.context`` and returned.
        """
        await self._log_event("research", step="start", details={
            "query": self.query,
            "report_type": self.report_type,
            "agent": self.agent,
            "role": self.role,
        })

        if not (self.agent and self.role):
            await self._log_event("action", action="choose_agent")
            self.agent, self.role = await choose_agent(
                query=self.query,
                cfg=self.cfg,
                parent_query=self.parent_query,
                cost_callback=self.add_costs,
                headers=self.headers,
            )
            await self._log_event("action", action="agent_selected", details={
                "agent": self.agent,
                "role": self.role,
            })

        await self._log_event("research", step="conducting_research", details={
            "agent": self.agent,
            "role": self.role,
        })
        self.context = await self.research_conductor.conduct_research()

        await self._log_event("research", step="research_completed", details={
            "context_length": len(self.context),
        })
        return self.context

    async def write_report(
        self,
        existing_headers: Optional[list] = None,
        relevant_written_contents: Optional[list] = None,
        ext_context=None,
    ) -> str:
        """Write the report via the report generator and return it.

        Args:
            existing_headers: Passed to the report generator; a fresh empty
                list is used when omitted.
            relevant_written_contents: Passed to the report generator; a
                fresh empty list is used when omitted.
            ext_context: Context to write from. When falsy, ``self.context``
                is used instead.

        Returns:
            The report produced by ``self.report_generator.write_report``.
        """
        # Fresh lists per call: a shared default list would carry any mutation
        # made by the generator over into later calls.
        if existing_headers is None:
            existing_headers = []
        if relevant_written_contents is None:
            relevant_written_contents = []

        await self._log_event("research", step="writing_report", details={
            "existing_headers": existing_headers,
            "context_source": "external" if ext_context else "internal",
        })
        report = await self.report_generator.write_report(
            existing_headers,
            relevant_written_contents,
            ext_context or self.context,
        )
        await self._log_event("research", step="report_completed", details={
            "report_length": len(report),
        })
        return report

    async def write_report_conclusion(self, report_body: str) -> str:
        """Return the conclusion the report generator writes for ``report_body``."""
        await self._log_event("research", step="writing_conclusion")
        conclusion = await self.report_generator.write_report_conclusion(report_body)
        await self._log_event("research", step="conclusion_completed")
        return conclusion

    async def write_introduction(self):
        """Return the introduction produced by the report generator."""
        await self._log_event("research", step="writing_introduction")
        intro = await self.report_generator.write_introduction()
        await self._log_event("research", step="introduction_completed")
        return intro

    async def get_subtopics(self):
        """Return the result of ``self.report_generator.get_subtopics()``."""
        return await self.report_generator.get_subtopics()

    async def get_draft_section_titles(self, current_subtopic: str):
        """Return the report generator's draft section titles for a subtopic."""
        return await self.report_generator.get_draft_section_titles(current_subtopic)

    async def get_similar_written_contents_by_draft_section_titles(
        self,
        current_subtopic: str,
        draft_section_titles: List[str],
        written_contents: List[Dict],
        max_results: int = 10,
    ) -> List[str]:
        """Delegate to the context manager's method of the same name.

        All arguments are forwarded unchanged, in the order given.
        """
        return await self.context_manager.get_similar_written_contents_by_draft_section_titles(
            current_subtopic,
            draft_section_titles,
            written_contents,
            max_results,
        )

    # Utility methods
    def get_research_images(self, top_k=10) -> List[Dict[str, Any]]:
        """Return the first ``top_k`` entries of ``self.research_images``."""
        return self.research_images[:top_k]

    def add_research_images(self, images: List[Dict[str, Any]]) -> None:
        """Append ``images`` to ``self.research_images``."""
        self.research_images.extend(images)

    def get_research_sources(self) -> List[Dict[str, Any]]:
        """Return ``self.research_sources`` (the live list, not a copy)."""
        return self.research_sources

    def add_research_sources(self, sources: List[Dict[str, Any]]) -> None:
        """Append ``sources`` to ``self.research_sources``."""
        self.research_sources.extend(sources)

    def add_references(self, report_markdown: str, visited_urls: set) -> str:
        """Return ``add_references(report_markdown, visited_urls)`` from ``.actions``."""
        return add_references(report_markdown, visited_urls)

    def extract_headers(self, markdown_text: str) -> List[Dict]:
        """Return ``extract_headers(markdown_text)`` from ``.actions``."""
        return extract_headers(markdown_text)

    def extract_sections(self, markdown_text: str) -> List[Dict]:
        """Return ``extract_sections(markdown_text)`` from ``.actions``."""
        return extract_sections(markdown_text)

    def table_of_contents(self, markdown_text: str) -> str:
        """Return ``table_of_contents(markdown_text)`` from ``.actions``."""
        return table_of_contents(markdown_text)

    def get_source_urls(self) -> list:
        """Return ``self.visited_urls`` as a new list."""
        return list(self.visited_urls)

    def get_research_context(self) -> list:
        """Return ``self.context``."""
        return self.context

    def get_costs(self) -> float:
        """Return the accumulated ``self.research_costs``."""
        return self.research_costs

    def set_verbose(self, verbose: bool):
        """Set ``self.verbose``."""
        self.verbose = verbose

    def add_costs(self, cost: float) -> None:
        """Add ``cost`` to the running total and emit a ``cost_update`` event.

        This method is synchronous (it is handed to ``choose_agent`` as
        ``cost_callback``), while ``_log_event`` is a coroutine. The event is
        therefore scheduled as a task on the running event loop, or run to
        completion on a temporary loop when none is running.

        Args:
            cost: Amount to add to ``self.research_costs``.

        Raises:
            ValueError: If ``cost`` is not an ``int`` or ``float``.
        """
        if not isinstance(cost, (float, int)):
            raise ValueError("Cost must be an integer or float")
        self.research_costs += cost

        if not self.log_handler:
            return

        import asyncio

        details = {"cost": cost, "total_cost": self.research_costs}
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: drive the coroutine ourselves.
            # ``_log_event`` swallows handler errors, so this cannot raise
            # because of the handler.
            asyncio.run(self._log_event("research", step="cost_update", details=details))
            return

        task = loop.create_task(
            self._log_event("research", step="cost_update", details=details)
        )
        # Keep a strong reference until the task finishes so it is not
        # garbage-collected mid-flight.
        self._pending_log_tasks.add(task)
        task.add_done_callback(self._pending_log_tasks.discard)