File size: 7,792 Bytes
3b993c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24d33b9
 
 
 
3b993c4
 
 
 
 
 
 
 
 
 
 
 
49b13c6
3b993c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dce8143
3b993c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""
openai.py

This module provides an asynchronous service for interacting with OpenAI's chat completions API.
It handles authentication, concurrency and rate limiting, and retrying of requests that fail
with rate limit errors.

The module includes:
1. An AsyncOpenAIService class for making API calls to OpenAI.
2. Translation of OpenAI rate limit errors into the shared RateLimitError
   (imported from vsp.llm.llm_service), which drives the retry logic.

Usage:
    async with AsyncOpenAIService(OpenAIModel.GPT_4_O)() as service:
        response = await service.invoke(
            user_prompt="Tell me a joke",
            max_tokens=100,
            temperature=0.7
        )
        print(response)

"""

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from vsp.llm.llm_service import LLMService, RateLimitError
from vsp.llm.openai.openai_model import OpenAIModel, get_openai_model_rate_limit
from vsp.llm.openai.openai_rate_limiter import OpenAIRateLimiter
from vsp.shared import aws_clients, config, logger_factory

# Module-level logger; calls in this module pass structured key/value context
# (e.g. error=..., model=...) alongside the message.
logger = logger_factory.get_logger(__name__)


class AsyncOpenAIService(LLMService):
    """
    An asynchronous service class for making calls to the OpenAI chat completions API.

    This class handles authentication (API key fetched from AWS Parameter Store),
    client-side concurrency limiting, rate limiting, and retrying of requests that
    fail with rate limit errors.

    The service owns an ``AsyncOpenAI`` client that is closed when the service is
    used as an async context manager, either directly::

        async with AsyncOpenAIService(model) as service:
            ...

    or through the ``__call__`` context manager::

        async with AsyncOpenAIService(model)() as service:
            ...
    """

    def __init__(self, model: OpenAIModel, max_concurrency: int = 30):
        """
        Initialize the AsyncOpenAIService.

        Args:
            model (OpenAIModel): The OpenAI model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 30.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        self._client = AsyncOpenAI(api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        rate_limit = get_openai_model_rate_limit(model)
        # Run at 95% of the model's configured limits to leave headroom; the floors
        # (1 request, 100 tokens) keep a tiny limit from rounding down to zero.
        self._rate_limiter = OpenAIRateLimiter(
            initial_rate_requests=int(max(rate_limit.requests_per_minute * 0.95, 1)),
            initial_rate_tokens=int(max(rate_limit.tokens_per_minute * 0.95, 100)),
            per=60.0,
        )

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenAI API key from AWS Parameter Store.

        Returns:
            str: The OpenAI API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openai_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenAI API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenAI API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenAI API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenAIService"]:
        """
        Yield this service and close the underlying OpenAI client on exit.

        Usage: ``async with service() as s: ...``
        """
        try:
            yield self
        finally:
            await self._client.close()

    async def __aenter__(self) -> "AsyncOpenAIService":
        """Allow ``async with AsyncOpenAIService(model) as service:`` directly."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the underlying OpenAI client; exceptions from the body propagate."""
        await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenAI API with the given prompts and parameters.

        This method limits concurrency, applies rate limiting, and makes the API call.
        Prompts that are None or empty are omitted from the message list.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
                If given, it is sent as a trailing assistant message and prepended to
                the returned text.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenAI API, or None if the API
            returned no choices or no message content.

        Raises:
            RateLimitError: If the API rate limit is still exceeded after all retry attempts.
            openai.APIError: For any other errors encountered during the API call.
        """
        async with self._semaphore:  # Limit the number of in-flight API calls
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            # Estimated token usage for the rate limiter: whitespace-separated word count of
            # the prompts plus the full completion budget. NOTE(review): word count is only a
            # rough proxy for tokenizer tokens and may differ noticeably from the real count.
            estimated_tokens = sum(len(m["content"].split()) for m in messages) + max_tokens

            await self._rate_limiter.acquire(estimated_tokens)

            response = await self.query_openai(max_tokens, temperature, messages)
            logger.info("OpenAI API called", model=self._model.value)

            # Guard against an empty choices list instead of raising IndexError.
            if not response.choices:
                logger.warning("No choices returned from OpenAI API")
                return None

            message = response.choices[0].message
            # Keep the raw value: content may be None, and str(None) would yield the
            # literal string "None" and bypass the None check below.
            text = message.content
            if message.refusal:
                logger.error("OpenAI refusal message", refusal=message.refusal)
            if text is None:
                logger.warning("No message content from OpenAI API")
                return None

            if partial_assistant_prompt:
                text = f"{partial_assistant_prompt}{text}"

            # Extract token usage information (usage may be absent on the response)
            usage = response.usage
            input_tokens = usage.prompt_tokens if usage else 0
            output_tokens = usage.completion_tokens if usage else 0

            # Log token usage
            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)

            return text

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(RateLimitError),
        # Re-raise the last RateLimitError instead of tenacity's RetryError so callers
        # receive the exception type documented below.
        reraise=True,
    )  # type: ignore
    async def query_openai(self, max_tokens: int, temperature: float, messages) -> ChatCompletion:
        """
        Make an API call to OpenAI with retry logic for rate limit errors.

        This method is decorated with a retry mechanism that attempts the call up to
        6 times in total, with randomized exponential backoff (1-60 s), whenever a
        RateLimitError is raised. Other errors are not retried.

        Args:
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.
            messages (list): List of message dictionaries to send to the API.

        Returns:
            ChatCompletion: The parsed response from the OpenAI API.

        Raises:
            RateLimitError: If the API rate limit is exceeded after all retry attempts.
            openai.APIError: For any other errors encountered during the API call.
        """
        try:
            # Use the raw-response variant so the rate-limit headers are available.
            response = await self._client.chat.completions.with_raw_response.create(
                model=self._model.value,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            # Update rate limit info based on response headers
            self._rate_limiter.update_rate_limit_info(response.headers)

            return response.parse()

        except openai.RateLimitError as e:
            logger.warning("Rate limit error encountered. Retrying...")
            raise RateLimitError("OpenAI API rate limit exceeded") from e
        except openai.APIError as e:
            logger.error("OpenAI API error", error=str(e))
            raise