navkast
LLM test harness + prompt caching + prompt tuning (#8)
49b13c6 unverified
"""
openai.py
This module provides an asynchronous service for interacting with OpenAI's API.
It handles authentication, rate limiting, and retrying of requests in case of rate limit errors.
The module includes:
1. Translation of OpenAI rate limit errors into the shared RateLimitError (imported from vsp.llm.llm_service).
2. An AsyncOpenAIService class for making API calls to OpenAI.
Usage:
    service = AsyncOpenAIService(OpenAIModel.GPT_4_O)
    async with service() as svc:
        response = await svc.invoke(
            user_prompt="Tell me a joke",
            max_tokens=100,
            temperature=0.7,
        )
        print(response)
"""
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator
import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from vsp.llm.llm_service import LLMService, RateLimitError
from vsp.llm.openai.openai_model import OpenAIModel, get_openai_model_rate_limit
from vsp.llm.openai.openai_rate_limiter import OpenAIRateLimiter
from vsp.shared import aws_clients, config, logger_factory
# Module-level logger; calls in this module pass structured keyword fields (e.g. error=..., model=...).
logger = logger_factory.get_logger(__name__)
class AsyncOpenAIService(LLMService):
    """
    An asynchronous service class for making calls to the OpenAI API.

    This class handles authentication, rate limiting, and retrying of requests
    when interacting with OpenAI's language models.

    Calling an instance returns an async context manager that yields the
    service and closes the underlying client on exit:

        service = AsyncOpenAIService(model)
        async with service() as svc:
            text = await svc.invoke(user_prompt="...")
    """

    def __init__(self, model: OpenAIModel, max_concurrency: int = 30):
        """
        Initialize the AsyncOpenAIService.

        Args:
            model (OpenAIModel): The OpenAI model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 30.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        self._client = AsyncOpenAI(api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model

        # Seed the limiter at 95% of the model's configured per-minute limits,
        # floored at 1 request and 100 tokens. query_openai() later passes the
        # response headers to the limiter via update_rate_limit_info().
        rate_limit = get_openai_model_rate_limit(model)
        self._rate_limiter = OpenAIRateLimiter(
            initial_rate_requests=int(max(rate_limit.requests_per_minute * 0.95, 1)),
            initial_rate_tokens=int(max(rate_limit.tokens_per_minute * 0.95, 100)),
            per=60.0,
        )

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenAI API key from AWS Parameter Store.

        Returns:
            str: The OpenAI API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openai_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenAI API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenAI API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenAI API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenAIService"]:
        """
        Provide the service as an async context manager.

        Yields:
            AsyncOpenAIService: This instance.

        The underlying OpenAI client is closed when the ``async with`` block
        exits, whether it finishes normally or raises.
        """
        try:
            yield self
        finally:
            await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenAI API with the given prompts and parameters.

        This method handles rate limiting and makes the API call. Prompts that
        are None or empty are omitted from the request.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
                When given, it is also prepended to the returned text.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenAI API, or None if the
            response message carries no content.

        Raises:
            RateLimitError: If the API rate limit is still exceeded after all retry attempts.
            openai.APIError: For any other errors encountered during the API call.
        """
        async with self._semaphore:  # Use semaphore to limit concurrency
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            # Rough token estimate for rate limiting: whitespace-separated word
            # count of the prompts plus the full response budget.
            estimated_tokens = sum(len(m["content"].split()) for m in messages) + max_tokens
            await self._rate_limiter.acquire(estimated_tokens)

            response = await self.query_openai(max_tokens, temperature, messages)
            logger.info("OpenAI API called", model=self._model.value)

            message = response.choices[0].message
            if message.refusal:
                logger.error("OpenAI refusal message", refusal=message.refusal)

            # Keep the raw content: wrapping it in str() would turn a missing
            # body into the literal string "None" and defeat the check below.
            text = message.content
            if text is None:
                logger.warning("No message content from OpenAI API")
                return None

            if partial_assistant_prompt:
                text = f"{partial_assistant_prompt}{text}"

            # Extract token usage information
            usage = response.usage
            input_tokens = usage.prompt_tokens if usage else 0
            output_tokens = usage.completion_tokens if usage else 0

            # Log token usage
            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)
            return text

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(RateLimitError),
        # Re-raise the last RateLimitError once attempts are exhausted, instead
        # of tenacity's default RetryError wrapper, so the documented exception
        # is the one callers actually see.
        reraise=True,
    )  # type: ignore
    async def query_openai(self, max_tokens, temperature, messages) -> ChatCompletion:
        """
        Make an API call to OpenAI with retry logic for rate limit errors.

        This method is decorated with a retry mechanism that will attempt
        the call up to 6 times with randomized exponential backoff if a
        RateLimitError is raised.

        Args:
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.
            messages (list): List of message dictionaries to send to the API.

        Returns:
            ChatCompletion: The response from the OpenAI API.

        Raises:
            RateLimitError: If the API rate limit is exceeded after all retry attempts.
            openai.APIError: For any other errors encountered during the API call.
        """
        try:
            # Use the raw-response variant so the HTTP headers are available.
            response = await self._client.chat.completions.with_raw_response.create(
                model=self._model.value,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            # Update rate limit info based on response headers
            self._rate_limiter.update_rate_limit_info(response.headers)
            return response.parse()
        except openai.RateLimitError as e:
            # Translate to the shared RateLimitError, which is the type the
            # retry decorator watches for.
            logger.warning("Rate limit error encountered. Retrying...")
            raise RateLimitError("OpenAI API rate limit exceeded") from e
        except openai.APIError as e:
            logger.error("OpenAI API error", error=str(e))
            raise