""" openai.py This module provides an asynchronous service for interacting with OpenAI's API. It handles authentication, rate limiting, and retrying of requests in case of rate limit errors. The module includes: 1. A custom RateLimitException for handling rate limit errors. 2. An AsyncOpenAIService class for making API calls to OpenAI. Usage: async with AsyncOpenAIService(OpenAIModel.GPT_4_O) as service: response = await service.invoke_openai( user_prompt="Tell me a joke", max_tokens=100, temperature=0.7 ) print(response) """ import asyncio from contextlib import asynccontextmanager from typing import AsyncIterator import openai from openai import AsyncOpenAI from openai.types.chat import ChatCompletion from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential from vsp.llm.llm_service import LLMService, RateLimitError from vsp.llm.openai.openai_model import OpenAIModel, get_openai_model_rate_limit from vsp.llm.openai.openai_rate_limiter import OpenAIRateLimiter from vsp.shared import aws_clients, config, logger_factory logger = logger_factory.get_logger(__name__) class AsyncOpenAIService(LLMService): """ An asynchronous service class for making calls to the OpenAI API. This class handles authentication, rate limiting, and retrying of requests when interacting with OpenAI's language models. """ def __init__(self, model: OpenAIModel, max_concurrency: int = 30): """ Initialize the AsyncOpenAIService. Args: model (OpenAIModel): The OpenAI model to use for API calls. max_concurrency (int): Maximum number of concurrent API calls. Defaults to 10. """ self._client = AsyncOpenAI(api_key=self._fetch_api_key()) self._semaphore = asyncio.Semaphore(max_concurrency) self._model = model rate_limit = get_openai_model_rate_limit(model) self._rate_limiter = OpenAIRateLimiter( initial_rate_requests=int(max(rate_limit.requests_per_minute * 0.95, 1)), initial_rate_tokens=int(max(rate_limit.tokens_per_minute * 0.95, 100)), per=60.0, ) @staticmethod def _fetch_api_key() -> str: """ Fetch the OpenAI API key from AWS Parameter Store. Returns: str: The OpenAI API key. Raises: ValueError: If the API key is not found in the Parameter Store. RuntimeError: If there's an error accessing the Parameter Store. """ try: return aws_clients.fetch_from_parameter_store(config.get_openai_api_key_path(), is_secret=True) except aws_clients.ParameterNotFoundError as e: logger.error("OpenAI API key not found in Parameter Store", error=str(e)) raise ValueError("OpenAI API key not found") from e except aws_clients.ParameterStoreAccessError as e: logger.error("Error accessing Parameter Store", error=str(e)) raise RuntimeError("Unable to access OpenAI API key") from e @asynccontextmanager async def __call__(self) -> AsyncIterator["AsyncOpenAIService"]: try: yield self finally: await self._client.close() async def invoke( self, user_prompt: str | None = None, system_prompt: str | None = None, partial_assistant_prompt: str | None = None, max_tokens: int = 1000, temperature: float = 0.0, ) -> str | None: """ Invoke the OpenAI API with the given prompts and parameters. This method handles rate limiting and makes the API call. Args: user_prompt (str | None): The main prompt from the user. system_prompt (str | None): A system message to set the context. partial_assistant_prompt (str | None): A partial response from the assistant. max_tokens (int): Maximum number of tokens in the response. temperature (float): Sampling temperature for response generation. Returns: str | None: The generated response from the OpenAI API, or None if no response. Raises: RateLimitError: If the API rate limit is exceeded. openai.APIError: For any other errors encountered during the API call. """ async with self._semaphore: # Use semaphore to limit concurrency messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) if user_prompt: messages.append({"role": "user", "content": user_prompt}) if partial_assistant_prompt: messages.append({"role": "assistant", "content": partial_assistant_prompt}) # Estimate token usage for rate limiting estimated_tokens = sum(len(m["content"].split()) for m in messages) + max_tokens await self._rate_limiter.acquire(estimated_tokens) response = await self.query_openai(max_tokens, temperature, messages) logger.info("OpenAI API called", model=self._model.value) message = response.choices[0].message text = str(message.content) if message.refusal: logger.error("OpenAI refusal message", refusal=message.refusal) if text is None: logger.warn("No message content from OpenAI API") return None if partial_assistant_prompt: text = f"{partial_assistant_prompt}{text}" # Extract token usage information usage = response.usage input_tokens = usage.prompt_tokens if usage else 0 output_tokens = usage.completion_tokens if usage else 0 # Log token usage logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens) return text @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_exception_type(RateLimitError), ) # type: ignore async def query_openai(self, max_tokens, temperature, messages) -> ChatCompletion: """ Make an API call to OpenAI with retry logic for rate limit errors. This method is decorated with a retry mechanism that will attempt to retry the call up to 6 times with exponential backoff if a RateLimitError is raised. Args: max_tokens (int): Maximum number of tokens in the response. temperature (float): Sampling temperature for response generation. messages (list): List of message dictionaries to send to the API. Returns: ChatCompletion: The response from the OpenAI API. Raises: RateLimitError: If the API rate limit is exceeded after all retry attempts. openai.APIError: For any other errors encountered during the API call. """ try: response = await self._client.chat.completions.with_raw_response.create( model=self._model.value, messages=messages, max_tokens=max_tokens, temperature=temperature, ) # Update rate limit info based on response headers self._rate_limiter.update_rate_limit_info(response.headers) return response.parse() except openai.RateLimitError as e: logger.warning("Rate limit error encountered. Retrying...") raise RateLimitError("OpenAI API rate limit exceeded") from e except openai.APIError as e: logger.error("OpenAI API error", error=str(e)) raise