"""
openai.py

This module provides an asynchronous service for interacting with OpenAI's API.
It handles authentication, rate limiting, and retrying of requests in case of rate limit errors.

The module includes:
1. Rate limit handling built on the RateLimitError imported from vsp.llm.llm_service.
2. An AsyncOpenAIService class for making API calls to OpenAI.

Usage:
    service = AsyncOpenAIService(OpenAIModel.GPT_4_O)
    async with service() as svc:
        response = await svc.invoke(
            user_prompt="Tell me a joke",
            max_tokens=100,
            temperature=0.7,
        )
        print(response)
"""
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from vsp.llm.llm_service import LLMService, RateLimitError
from vsp.llm.openai.openai_model import OpenAIModel, get_openai_model_rate_limit
from vsp.llm.openai.openai_rate_limiter import OpenAIRateLimiter
from vsp.shared import aws_clients, config, logger_factory

logger = logger_factory.get_logger(__name__)


class AsyncOpenAIService(LLMService):
    """
    An asynchronous service class for making calls to the OpenAI API.

    This class handles authentication, rate limiting, and retrying of requests
    when interacting with OpenAI's language models.
    """

    def __init__(self, model: OpenAIModel, max_concurrency: int = 30):
        """
        Initialize the AsyncOpenAIService.

        Args:
            model (OpenAIModel): The OpenAI model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 30.
        """
        self._client = AsyncOpenAI(api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        rate_limit = get_openai_model_rate_limit(model)
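        # Start slightly below the published limits (95%, with small floors) so a
        # burst of concurrent requests is less likely to trip OpenAI's limiter.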
        self._rate_limiter = OpenAIRateLimiter(
            initial_rate_requests=int(max(rate_limit.requests_per_minute * 0.95, 1)),
            initial_rate_tokens=int(max(rate_limit.tokens_per_minute * 0.95, 100)),
            per=60.0,
        )

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenAI API key from AWS Parameter Store.

        Returns:
            str: The OpenAI API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openai_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenAI API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenAI API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenAI API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenAIService"]:
        """Yield the service as an async context manager, closing the underlying client on exit."""
        try:
            yield self
        finally:
            await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenAI API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenAI API, or None if no response.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            openai.APIError: For any other errors encountered during the API call.
        """
        async with self._semaphore:
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})
            estimated_tokens = sum(len(m["content"].split()) for m in messages) + max_tokens

            await self._rate_limiter.acquire(estimated_tokens)

            response = await self.query_openai(max_tokens, temperature, messages)
            logger.info("OpenAI API called", model=self._model.value)

            message = response.choices[0].message
            text = message.content
            if message.refusal:
                logger.error("OpenAI refusal message", refusal=message.refusal)
            if text is None:
                logger.warning("No message content from OpenAI API")
                return None

            if partial_assistant_prompt:
                text = f"{partial_assistant_prompt}{text}"

            usage = response.usage
            input_tokens = usage.prompt_tokens if usage else 0
            output_tokens = usage.completion_tokens if usage else 0

            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)

            return text

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(RateLimitError),
    )
    async def query_openai(
        self, max_tokens: int, temperature: float, messages: list[dict[str, str]]
    ) -> ChatCompletion:
        """
        Make an API call to OpenAI with retry logic for rate limit errors.

        This method is decorated with a retry mechanism that will attempt to retry
        the call up to 6 times with exponential backoff if a RateLimitError is raised.

        Args:
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.
            messages (list): List of message dictionaries to send to the API.

        Returns:
            ChatCompletion: The response from the OpenAI API.

        Raises:
            RateLimitError: If the API rate limit is exceeded after all retry attempts.
            openai.APIError: For any other errors encountered during the API call.
        """
        try:
            response = await self._client.chat.completions.with_raw_response.create(
                model=self._model.value,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
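
            # The raw response wrapper exposes the HTTP headers, which the rate
            # limiter inspects for OpenAI's current rate limit state.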
            self._rate_limiter.update_rate_limit_info(response.headers)

            return response.parse()

        except openai.RateLimitError as e:
            logger.warning("Rate limit error encountered. Retrying...")
            raise RateLimitError("OpenAI API rate limit exceeded") from e
        except openai.APIError as e:
            logger.error("OpenAI API error", error=str(e))
            raise