import asyncio
import time
from contextlib import asynccontextmanager
from typing import AsyncIterator

import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from vsp.llm.llm_service import LLMService, RateLimitError
from vsp.shared import aws_clients, config, logger_factory

logger = logger_factory.get_logger(__name__)


class AsyncOpenRouterService(LLMService):
    """
    An asynchronous service class for making calls to the OpenRouter API.

    This class handles authentication and implements a basic rate limiting strategy.
    """

    def __init__(self, model: str, max_concurrency: int = 3, requests_per_minute: int = 60):
        """
        Initialize the AsyncOpenRouterService.

        Args:
            model (str): The OpenRouter model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 3.
            requests_per_minute (int): Maximum number of requests allowed per minute. Defaults to 60.
        """
        self._client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        self._requests_per_minute = requests_per_minute
        self._request_times: list[float] = []

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenRouter API key from AWS Parameter Store.

        Returns:
            str: The OpenRouter API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openrouter_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenRouter API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenRouter API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenRouter API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenRouterService"]:
        """
        Yield the service for use in an ``async with`` block, closing the
        underlying HTTP client when the block exits.
        """
        try:
            yield self
        finally:
            await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenRouter API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenRouter API, or None if no response.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            openai.APIError: For any other errors encountered during the API call.
        """
        async with self._semaphore:  # Use semaphore to limit concurrency
            await self._wait_for_rate_limit()

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            response = await self.query_openrouter(max_tokens, temperature, messages)
            logger.info("OpenRouter API called", model=self._model)

            self._update_request_times()

            message = response.choices[0].message
            text = message.content
            if text is None:
                logger.warning("No message content from OpenRouter API")
                return None

            if partial_assistant_prompt:
                text = f"{partial_assistant_prompt}{text}"

            # Extract token usage information
            usage = response.usage
            input_tokens = usage.prompt_tokens if usage else 0
            output_tokens = usage.completion_tokens if usage else 0

            # Log token usage
            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)

            return text

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(RateLimitError),
    )  # type: ignore
    async def query_openrouter(
        self, max_tokens: int, temperature: float, messages: list[dict[str, str]]
    ) -> ChatCompletion:
        """
        Send a chat completion request to the OpenRouter API.

        Retries with random exponential backoff (1-60 seconds, up to six
        attempts) when a RateLimitError is raised.
        """
        try:
            response = await self._client.chat.completions.create(
                model=self._model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            return response

        except openai.RateLimitError as e:
            logger.warning("Rate limit error encountered. Retrying...")
            raise RateLimitError("OpenRouter API rate limit exceeded") from e
        except openai.APIError as e:
            logger.error("OpenRouter API error", error=str(e))
            raise

    def _update_request_times(self) -> None:
        """
        Update the list of request times, removing any that are older than one minute.
        """
        current_time = time.time()
        self._request_times = [t for t in self._request_times if current_time - t < 60]
        self._request_times.append(current_time)

    async def _wait_for_rate_limit(self) -> None:
        """
        Wait if necessary to respect the rate limit.
        """
        while len(self._request_times) >= self._requests_per_minute:
            current_time = time.time()
            oldest_request_time = self._request_times[0]
            if current_time - oldest_request_time < 60:
                wait_time = 60 - (current_time - oldest_request_time)
                logger.info(f"Rate limit reached. Waiting for {wait_time:.2f} seconds.")
                await asyncio.sleep(wait_time)
                current_time = time.time()  # refresh so the prune sees post-sleep time
            self._request_times = [t for t in self._request_times if current_time - t < 60]
        logger.debug(f"Requests in the last minute: {len(self._request_times)}")
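

# Usage sketch (illustrative only, not part of the service): drives the
# context manager and invoke() end to end. Assumes AWS Parameter Store is
# configured so _fetch_api_key() succeeds; the model id below is an example
# value, not a project default.
async def _example() -> None:
    service = AsyncOpenRouterService(model="anthropic/claude-3.5-sonnet")
    async with service() as llm:
        answer = await llm.invoke(
            user_prompt="Summarize the CAP theorem in one sentence.",
            system_prompt="You are a concise technical assistant.",
            max_tokens=200,
        )
        print(answer)


if __name__ == "__main__":
    asyncio.run(_example())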