import os
import re
import time
import logging

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel

from .config import settings

logger = logging.getLogger(__name__)

# Store initialized models to avoid re-creating them repeatedly
_llm_cache = {}


def get_llm(model_name: str) -> BaseChatModel:
    """
    Returns an initialized LangChain chat model based on the provided name.
    Caches initialized models.
    """
    if model_name in _llm_cache:
        return _llm_cache[model_name]

    logger.info(f"Initializing LLM: {model_name}")

    if model_name.startswith("gemini"):
        if not settings.gemini_api_key:
            raise ValueError("GEMINI_API_KEY is not configured.")
        try:
            # Uses the GOOGLE_API_KEY environment variable set in config.py
            llm = ChatGoogleGenerativeAI(model=model_name)
            _llm_cache[model_name] = llm
            logger.info(f"Initialized Google Generative AI model: {model_name}")
            return llm
        except Exception as e:
            logger.error(f"Failed to initialize Gemini model '{model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize Gemini model: {e}") from e

    elif model_name.startswith("gpt"):
        if not settings.openai_api_key:
            raise ValueError("OPENAI_API_KEY is not configured.")
        try:
            # A base_url can be passed here if using a proxy,
            # e.g. base_url="https://your-proxy-if-needed/"
            llm = ChatOpenAI(model=model_name, api_key=settings.openai_api_key)
            _llm_cache[model_name] = llm
            logger.info(f"Initialized OpenAI model: {model_name}")
            return llm
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI model '{model_name}': {e}", exc_info=True)
            raise RuntimeError(f"Could not initialize OpenAI model: {e}") from e

    # Add other model providers (Anthropic, Groq, etc.) here if needed
    else:
        logger.error(f"Unsupported model provider for model name: {model_name}")
        raise ValueError(f"Model '{model_name}' is not supported or configuration is missing.")
def invoke_llm(llm: BaseChatModel, parameters):
    """
    Invokes the chat model and, on failure (e.g. a rate-limit error), waits and
    retries once. The wait time is parsed from the provider's retry_delay hint
    in the error message when available, otherwise it falls back to 60 seconds.
    """
    try:
        return llm.invoke(parameters)
    except Exception as e:
        # Try to extract the retry_delay seconds from the error message string
        match = re.search(r'retry_delay\s*{\s*seconds:\s*(\d+)', str(e))
        if match:
            retry_seconds = int(match.group(1)) + 1  # Add a 1-second buffer
        else:
            retry_seconds = 60  # Fall back to 60 seconds if no hint is found
        logger.warning(f"Error during .invoke: {e}\nWaiting {retry_seconds} seconds before retrying")
        time.sleep(retry_seconds)
        logger.info("Waiting finished, retrying invoke")
        return llm.invoke(parameters)
# Example usage (could be called from other modules)
# main_llm = get_llm(settings.main_llm_model)
# eval_llm = get_llm(settings.eval_llm_model)
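
# A minimal end-to-end sketch (hypothetical prompt; assumes settings.main_llm_model
# is defined in .config, as in the commented example above):
#
# from langchain_core.messages import HumanMessage
#
# llm = get_llm(settings.main_llm_model)
# response = invoke_llm(llm, [HumanMessage(content="Summarize retry handling in two sentences.")])
# print(response.content)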