File size: 7,792 Bytes
3b993c4 24d33b9 3b993c4 49b13c6 3b993c4 dce8143 3b993c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
"""
openai.py
This module provides an asynchronous service for interacting with OpenAI's API.
It handles authentication, rate limiting, and retrying of requests in case of rate limit errors.
The module includes:
1. A custom RateLimitException for handling rate limit errors.
2. An AsyncOpenAIService class for making API calls to OpenAI.
Usage:
    service = AsyncOpenAIService(OpenAIModel.GPT_4_O)
    async with service() as svc:
        response = await svc.invoke(
            user_prompt="Tell me a joke",
            max_tokens=100,
            temperature=0.7,
        )
        print(response)
"""
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator
import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from vsp.llm.llm_service import LLMService, RateLimitError
from vsp.llm.openai.openai_model import OpenAIModel, get_openai_model_rate_limit
from vsp.llm.openai.openai_rate_limiter import OpenAIRateLimiter
from vsp.shared import aws_clients, config, logger_factory
# Module-level structured logger, named after this module.
logger = logger_factory.get_logger(__name__)
class AsyncOpenAIService(LLMService):
    """
    An asynchronous service class for making calls to the OpenAI API.

    This class handles authentication (API key fetched from AWS Parameter
    Store), client-side rate limiting (requests and tokens per minute),
    bounded concurrency via a semaphore, and retrying of requests when
    interacting with OpenAI's language models.
    """

    def __init__(self, model: OpenAIModel, max_concurrency: int = 30):
        """
        Initialize the AsyncOpenAIService.

        Args:
            model (OpenAIModel): The OpenAI model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 30.
        """
        self._client = AsyncOpenAI(api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        # Start the limiter at 95% of the model's published limits to leave
        # headroom; the floors (1 request / 100 tokens) guard against a
        # configured limit so small that the limiter would round down to zero.
        rate_limit = get_openai_model_rate_limit(model)
        self._rate_limiter = OpenAIRateLimiter(
            initial_rate_requests=int(max(rate_limit.requests_per_minute * 0.95, 1)),
            initial_rate_tokens=int(max(rate_limit.tokens_per_minute * 0.95, 100)),
            per=60.0,
        )

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenAI API key from AWS Parameter Store.

        Returns:
            str: The OpenAI API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openai_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenAI API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenAI API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenAI API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenAIService"]:
        """Yield this service as an async context manager, closing the
        underlying HTTP client on exit (even if the body raises)."""
        try:
            yield self
        finally:
            await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenAI API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the
                assistant; it is prepended to the returned text.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenAI API, or None if no response.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            openai.APIError: For any other errors encountered during the API call.
        """
        async with self._semaphore:  # Use semaphore to limit concurrency
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            # Rough token estimate for rate limiting: whitespace-delimited
            # word count of the prompts plus the response budget.
            estimated_tokens = sum(len(m["content"].split()) for m in messages) + max_tokens
            await self._rate_limiter.acquire(estimated_tokens)

            response = await self.query_openai(max_tokens, temperature, messages)
            logger.info("OpenAI API called", model=self._model.value)

            message = response.choices[0].message
            if message.refusal:
                logger.error("OpenAI refusal message", refusal=message.refusal)
            # BUG FIX: the None check must happen BEFORE str() conversion;
            # str(None) is the string "None", so the old `if text is None`
            # guard could never fire.
            if message.content is None:
                logger.warning("No message content from OpenAI API")
                return None
            text = str(message.content)
            if partial_assistant_prompt:
                text = f"{partial_assistant_prompt}{text}"

            # Extract and log actual token usage reported by the API.
            usage = response.usage
            input_tokens = usage.prompt_tokens if usage else 0
            output_tokens = usage.completion_tokens if usage else 0
            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)

            return text

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(RateLimitError),
    )  # type: ignore
    async def query_openai(
        self, max_tokens: int, temperature: float, messages: list[dict[str, str]]
    ) -> ChatCompletion:
        """
        Make an API call to OpenAI with retry logic for rate limit errors.

        This method is decorated with a retry mechanism that will attempt to retry
        the call up to 6 times with exponential backoff if a RateLimitError is raised.

        Args:
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.
            messages (list): List of message dictionaries to send to the API.

        Returns:
            ChatCompletion: The response from the OpenAI API.

        Raises:
            RateLimitError: If the API rate limit is exceeded after all retry attempts.
            openai.APIError: For any other errors encountered during the API call.
        """
        try:
            # with_raw_response exposes HTTP headers so the limiter can learn
            # the server-reported remaining quota before we parse the body.
            response = await self._client.chat.completions.with_raw_response.create(
                model=self._model.value,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            self._rate_limiter.update_rate_limit_info(response.headers)
            return response.parse()
        except openai.RateLimitError as e:
            logger.warning("Rate limit error encountered. Retrying...")
            # Translate to the project's RateLimitError so the tenacity
            # decorator above triggers a retry.
            raise RateLimitError("OpenAI API rate limit exceeded") from e
        except openai.APIError as e:
            logger.error("OpenAI API error", error=str(e))
            raise
|