import json
import os
import warnings
from typing import Dict, Any, List, Union, Type, get_origin, get_args

from .variables.default import DEFAULT_CONFIG
from .variables.base import BaseConfig
from ..retrievers.utils import get_all_retriever_names


class Config:
    """Config class for GPT Researcher."""

    CONFIG_DIR = os.path.join(os.path.dirname(__file__), "variables")

    def __init__(self, config_path: str | None = None):
        """Initialize the config class."""
        self.config_path = config_path
        self.llm_kwargs: Dict[str, Any] = {}
        self.embedding_kwargs: Dict[str, Any] = {}

        config_to_use = self.load_config(config_path)
        self._set_attributes(config_to_use)
        self._set_embedding_attributes()
        self._set_llm_attributes()
        self._handle_deprecated_attributes()
        self._set_doc_path(config_to_use)

    def _set_attributes(self, config: Dict[str, Any]) -> None:
        for key, value in config.items():
            env_value = os.getenv(key)
            if env_value is not None:
                value = self.convert_env_value(key, env_value, BaseConfig.__annotations__[key])
            setattr(self, key.lower(), value)

        # Handle RETRIEVER with a default value
        retriever_env = os.environ.get("RETRIEVER", config.get("RETRIEVER", "tavily"))
        try:
            self.retrievers = self.parse_retrievers(retriever_env)
        except ValueError as e:
            print(f"Warning: {str(e)}. Defaulting to 'tavily' retriever.")
            self.retrievers = ["tavily"]
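
    # Note: in the loop above, environment variables take precedence over
    # config-file values; e.g. exporting TEMPERATURE=0.4 overrides
    # config["TEMPERATURE"] (assuming TEMPERATURE is a key declared on
    # BaseConfig, since the type hint is looked up in its __annotations__).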

    def _set_embedding_attributes(self) -> None:
        self.embedding_provider, self.embedding_model = self.parse_embedding(
            self.embedding
        )

    def _set_llm_attributes(self) -> None:
        self.fast_llm_provider, self.fast_llm_model = self.parse_llm(self.fast_llm)
        self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
        self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(self.strategic_llm)

    def _handle_deprecated_attributes(self) -> None:
        if os.getenv("EMBEDDING_PROVIDER") is not None:
            warnings.warn(
                "EMBEDDING_PROVIDER is deprecated and will be removed soon. Use EMBEDDING instead.",
                FutureWarning,
                stacklevel=2,
            )
            self.embedding_provider = (
                os.environ["EMBEDDING_PROVIDER"] or self.embedding_provider
            )

            match os.environ["EMBEDDING_PROVIDER"]:
                case "ollama":
                    self.embedding_model = os.environ["OLLAMA_EMBEDDING_MODEL"]
                case "custom":
                    self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "custom")
                case "openai":
                    self.embedding_model = "text-embedding-3-large"
                case "azure_openai":
                    self.embedding_model = "text-embedding-3-large"
                case "huggingface":
                    self.embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
                case _:
                    raise Exception("Embedding provider not found.")

        _deprecation_warning = (
            "LLM_PROVIDER, FAST_LLM_MODEL and SMART_LLM_MODEL are deprecated and "
            "will be removed soon. Use FAST_LLM and SMART_LLM instead."
        )
        if os.getenv("LLM_PROVIDER") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.fast_llm_provider = (
                os.environ["LLM_PROVIDER"] or self.fast_llm_provider
            )
            self.smart_llm_provider = (
                os.environ["LLM_PROVIDER"] or self.smart_llm_provider
            )
        if os.getenv("FAST_LLM_MODEL") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.fast_llm_model = os.environ["FAST_LLM_MODEL"] or self.fast_llm_model
        if os.getenv("SMART_LLM_MODEL") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.smart_llm_model = os.environ["SMART_LLM_MODEL"] or self.smart_llm_model

    def _set_doc_path(self, config: Dict[str, Any]) -> None:
        self.doc_path = config['DOC_PATH']
        if self.doc_path:
            try:
                self.validate_doc_path()
            except Exception as e:
                print(f"Warning: Error validating doc_path: {str(e)}. Using default doc_path.")
                self.doc_path = DEFAULT_CONFIG['DOC_PATH']

    @classmethod
    def load_config(cls, config_path: str | None) -> Dict[str, Any]:
        """Load a configuration by name."""
        if config_path is None:
            return DEFAULT_CONFIG

        # config_path = os.path.join(cls.CONFIG_DIR, config_path)
        if not os.path.exists(config_path):
            if config_path and config_path != "default":
                print(f"Warning: Configuration not found at '{config_path}'. Using default configuration.")
                if not config_path.endswith(".json"):
                    print(f"Did you mean '{config_path}.json'?")
            return DEFAULT_CONFIG

        with open(config_path, "r") as f:
            custom_config = json.load(f)

        # Merge with the default config to ensure all keys are present
        merged_config = DEFAULT_CONFIG.copy()
        merged_config.update(custom_config)
        return merged_config
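
    # Illustrative only: with a hypothetical my_config.json containing
    # {"RETRIEVER": "arxiv"}, the custom value is merged over DEFAULT_CONFIG
    # and every other key keeps its default:
    #
    #   cfg = Config.load_config("my_config.json")
    #   cfg["RETRIEVER"]  # -> "arxiv"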

    @classmethod
    def list_available_configs(cls) -> List[str]:
        """List all available configuration names."""
        configs = ["default"]
        for file in os.listdir(cls.CONFIG_DIR):
            if file.endswith(".json"):
                configs.append(file[:-5])  # Strip the ".json" extension
        return configs

    def parse_retrievers(self, retriever_str: str) -> List[str]:
        """Parse the retriever string into a list of retrievers and validate them."""
        retrievers = [retriever.strip() for retriever in retriever_str.split(",")]
        valid_retrievers = get_all_retriever_names() or []
        invalid_retrievers = [r for r in retrievers if r not in valid_retrievers]
        if invalid_retrievers:
            raise ValueError(
                f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
                f"Valid options are: {', '.join(valid_retrievers)}."
            )
        return retrievers
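
    # Illustrative only (assuming "tavily" and "arxiv" appear in
    # get_all_retriever_names()):
    #
    #   Config().parse_retrievers("tavily, arxiv")    # -> ["tavily", "arxiv"]
    #   Config().parse_retrievers("no_such_backend")  # raises ValueError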

    @staticmethod
    def parse_llm(llm_str: str | None) -> tuple[str | None, str | None]:
        """Parse an LLM string into (llm_provider, llm_model)."""
        from gpt_researcher.llm_provider.generic.base import _SUPPORTED_PROVIDERS

        if llm_str is None:
            return None, None
        try:
            llm_provider, llm_model = llm_str.split(":", 1)
            assert llm_provider in _SUPPORTED_PROVIDERS, (
                f"Unsupported LLM provider '{llm_provider}'.\nSupported LLM providers are: "
                + ", ".join(_SUPPORTED_PROVIDERS)
            )
            return llm_provider, llm_model
        except ValueError:
            raise ValueError(
                "Set SMART_LLM or FAST_LLM = '<llm_provider>:<llm_model>', "
                "e.g. 'openai:gpt-4o-mini'"
            )
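
    # Illustrative only: the expected format is "<provider>:<model>", split on
    # the first colon so model names may themselves contain colons.
    #
    #   Config.parse_llm("openai:gpt-4o-mini")  # -> ("openai", "gpt-4o-mini")
    #   Config.parse_llm(None)                  # -> (None, None)
    #   Config.parse_llm("gpt-4o-mini")         # raises ValueError (no colon)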

    @staticmethod
    def parse_embedding(embedding_str: str | None) -> tuple[str | None, str | None]:
        """Parse an embedding string into (embedding_provider, embedding_model)."""
        from gpt_researcher.memory.embeddings import _SUPPORTED_PROVIDERS

        if embedding_str is None:
            return None, None
        try:
            embedding_provider, embedding_model = embedding_str.split(":", 1)
            assert embedding_provider in _SUPPORTED_PROVIDERS, (
                f"Unsupported embedding provider '{embedding_provider}'.\nSupported embedding providers are: "
                + ", ".join(_SUPPORTED_PROVIDERS)
            )
            return embedding_provider, embedding_model
        except ValueError:
            raise ValueError(
                "Set EMBEDDING = '<embedding_provider>:<embedding_model>', "
                "e.g. 'openai:text-embedding-3-large'"
            )
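
    # Illustrative only:
    #
    #   Config.parse_embedding("openai:text-embedding-3-large")
    #   # -> ("openai", "text-embedding-3-large")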

    def validate_doc_path(self):
        """Ensure that the folder exists at the doc path."""
        os.makedirs(self.doc_path, exist_ok=True)

    @staticmethod
    def convert_env_value(key: str, env_value: str, type_hint: Type) -> Any:
        """Convert an environment variable to the appropriate type based on the type hint."""
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        if origin is Union:
            # Handle Union types (e.g., Union[str, None])
            for arg in args:
                if arg is type(None):
                    if env_value.lower() in ("none", "null", ""):
                        return None
                else:
                    try:
                        return Config.convert_env_value(key, env_value, arg)
                    except ValueError:
                        continue
            raise ValueError(f"Cannot convert {env_value} to any of {args}")

        if type_hint is bool:
            return env_value.lower() in ("true", "1", "yes", "on")
        elif type_hint is int:
            return int(env_value)
        elif type_hint is float:
            return float(env_value)
        elif type_hint in (str, Any):
            return env_value
        elif origin is list or origin is List:
            return json.loads(env_value)
        else:
            raise ValueError(f"Unsupported type {type_hint} for key {key}")