import asyncio
import time
from contextlib import asynccontextmanager
from typing import AsyncIterator

import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from vsp.llm.llm_service import LLMService, RateLimitError
from vsp.shared import aws_clients, config, logger_factory

logger = logger_factory.get_logger(__name__)


class AsyncOpenRouterService(LLMService):
    """
    An asynchronous service class for making calls to the OpenRouter API.

    This class handles authentication and implements a basic rate limiting strategy.
    """

    def __init__(self, model: str, max_concurrency: int = 3, requests_per_minute: int = 60):
        """
        Initialize the AsyncOpenRouterService.

        Args:
            model (str): The OpenRouter model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 3.
            requests_per_minute (int): Maximum number of requests allowed per minute. Defaults to 60.
        """
        self._client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        self._requests_per_minute = requests_per_minute
        self._request_times: list[float] = []

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenRouter API key from AWS Parameter Store.

        Returns:
            str: The OpenRouter API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openrouter_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenRouter API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenRouter API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenRouter API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenRouterService"]:
        try:
            yield self
        finally:
            await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenRouter API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenRouter API, or None if no response.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            openai.APIError: For any other errors encountered during the API call.
""" async with self._semaphore: # Use semaphore to limit concurrency await self._wait_for_rate_limit() messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) if user_prompt: messages.append({"role": "user", "content": user_prompt}) if partial_assistant_prompt: messages.append({"role": "assistant", "content": partial_assistant_prompt}) response = await self.query_openrouter(max_tokens, temperature, messages) logger.info("OpenRouter API called", model=self._model) self._update_request_times() message = response.choices[0].message text = str(message.content) if text is None: logger.warn("No message content from OpenRouter API") return None if partial_assistant_prompt: text = f"{partial_assistant_prompt}{text}" # Extract token usage information usage = response.usage input_tokens = usage.prompt_tokens if usage else 0 output_tokens = usage.completion_tokens if usage else 0 # Log token usage logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens) return text @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_exception_type(RateLimitError), ) # type: ignore async def query_openrouter(self, max_tokens, temperature, messages) -> ChatCompletion: try: response = await self._client.chat.completions.create( model=self._model, messages=messages, max_tokens=max_tokens, temperature=temperature, ) return response except openai.RateLimitError as e: logger.warning("Rate limit error encountered. Retrying...") raise RateLimitError("OpenRouter API rate limit exceeded") from e except openai.APIError as e: logger.error("OpenRouter API error", error=str(e)) raise def _update_request_times(self) -> None: """ Update the list of request times, removing any that are older than one minute. """ current_time = time.time() self._request_times = [t for t in self._request_times if current_time - t < 60] self._request_times.append(current_time) async def _wait_for_rate_limit(self) -> None: """ Wait if necessary to respect the rate limit. """ while len(self._request_times) >= self._requests_per_minute: current_time = time.time() oldest_request_time = self._request_times[0] if current_time - oldest_request_time < 60: wait_time = 60 - (current_time - oldest_request_time) logger.info(f"Rate limit reached. Waiting for {wait_time:.2f} seconds.") await asyncio.sleep(wait_time) self._request_times = [t for t in self._request_times if current_time - t < 60] logger.debug(f"Requests in the last minute: {len(self._request_times)}")