File size: 6,907 Bytes
a1ead4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import asyncio
import time
from contextlib import asynccontextmanager
from typing import AsyncIterator
import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from vsp.llm.llm_service import LLMService, RateLimitError
from vsp.shared import aws_clients, config, logger_factory
# Module-level logger for this service. Calls in this module pass keyword fields
# (e.g. error=..., model=...), so this is presumably a structured logger — confirm
# against logger_factory.
logger = logger_factory.get_logger(__name__)
class AsyncOpenRouterService(LLMService):
    """
    An asynchronous service class for making calls to the OpenRouter API.

    Requests are sent through the ``AsyncOpenAI`` client pointed at
    ``https://openrouter.ai/api/v1``. Concurrency is bounded by a semaphore, and a
    client-side sliding window limits calls to ``requests_per_minute`` per 60 seconds.
    """

    def __init__(self, model: str, max_concurrency: int = 3, requests_per_minute: int = 60):
        """
        Initialize the AsyncOpenRouterService.

        Args:
            model (str): The OpenRouter model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 3.
            requests_per_minute (int): Maximum number of requests allowed per minute. Defaults to 60.

        Raises:
            ValueError: If ``max_concurrency`` or ``requests_per_minute`` is less than 1,
                or if the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        # Validate before creating the client so a bad argument does not leave an open client behind.
        # A value of 0 would otherwise deadlock the semaphore or crash the rate limiter later.
        if max_concurrency < 1:
            raise ValueError("max_concurrency must be at least 1")
        if requests_per_minute < 1:
            raise ValueError("requests_per_minute must be at least 1")

        self._client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        self._requests_per_minute = requests_per_minute
        # time.monotonic() timestamps of requests sent, oldest first; pruned to the last 60 seconds.
        self._request_times: list[float] = []

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenRouter API key from AWS Parameter Store.

        Returns:
            str: The OpenRouter API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openrouter_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenRouter API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenRouter API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenRouter API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenRouterService"]:
        """
        Use the service as an async context manager (``async with service() as svc:``).

        Yields this instance and closes the underlying AsyncOpenAI client on exit,
        even if the body raises.
        """
        try:
            yield self
        finally:
            await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenRouter API with the given prompts and parameters.

        This method enforces the concurrency and rate limits, then makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant;
                it is prepended to the returned text.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenRouter API, or None if the
            response has no choices or no message content.

        Raises:
            tenacity.RetryError: If the request is still rate-limited after all retry
                attempts in ``query_openrouter``; it wraps the last RateLimitError.
            openai.APIError: For any other errors encountered during the API call.
        """
        async with self._semaphore:  # Use semaphore to limit concurrency
            await self._wait_for_rate_limit()
            # Reserve the rate-limit slot now, before awaiting the API call. There is no
            # await between the check above and this append, so concurrent invocations
            # cannot all pass the check before any of them is recorded. Requests that
            # later fail still count toward the window.
            self._update_request_times()

            messages: list[dict[str, str]] = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            response = await self.query_openrouter(max_tokens, temperature, messages)
            logger.info("OpenRouter API called", model=self._model)

            if not response.choices:
                logger.warning("No message content from OpenRouter API")
                return None

            # Use the content as-is: str() would turn a missing (None) content into the string "None".
            text = response.choices[0].message.content
            if text is None:
                logger.warning("No message content from OpenRouter API")
                return None

            if partial_assistant_prompt:
                text = f"{partial_assistant_prompt}{text}"

            # Extract and log token usage; usage may be absent, in which case report zeros.
            usage = response.usage
            input_tokens = usage.prompt_tokens if usage else 0
            output_tokens = usage.completion_tokens if usage else 0
            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)
            return text

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(RateLimitError),
    )  # type: ignore
    async def query_openrouter(self, max_tokens: int, temperature: float, messages) -> ChatCompletion:
        """
        Send a single chat-completion request, retrying while it is rate-limited.

        The call is made up to 6 times, with randomized exponential backoff (1–60 s),
        for as long as it raises RateLimitError. These retries do not pass through the
        client-side limiter in ``invoke``. Because ``reraise`` is not set, tenacity
        raises ``tenacity.RetryError`` once the attempts are exhausted.

        Args:
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature.
            messages: Chat messages as a list of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            ChatCompletion: The raw completion response.

        Raises:
            RateLimitError: On an OpenAI rate-limit error; this triggers a retry.
            openai.APIError: For any other API error; this is not retried.
        """
        try:
            response = await self._client.chat.completions.create(
                model=self._model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return response
        except openai.RateLimitError as e:
            logger.warning("Rate limit error encountered. Retrying...")
            raise RateLimitError("OpenRouter API rate limit exceeded") from e
        except openai.APIError as e:
            logger.error("OpenRouter API error", error=str(e))
            raise

    def _update_request_times(self) -> None:
        """
        Record a request at the current time, removing any entries older than one minute.
        """
        current_time = time.monotonic()
        self._request_times = [t for t in self._request_times if current_time - t < 60]
        self._request_times.append(current_time)

    async def _wait_for_rate_limit(self) -> None:
        """
        Wait if necessary to respect the rate limit.

        Returns once fewer than ``requests_per_minute`` requests are recorded in the
        last 60 seconds. If a slot is already free, it returns without sleeping.
        """
        while True:
            # Take a fresh timestamp every pass. After a sleep, the expired entries must be
            # pruned relative to the time we woke up, not the time we started waiting.
            current_time = time.monotonic()
            self._request_times = [t for t in self._request_times if current_time - t < 60]
            if len(self._request_times) < self._requests_per_minute:
                break
            # The oldest entry is < 60 s old here, so wait_time is strictly positive.
            wait_time = 60 - (current_time - self._request_times[0])
            logger.info(f"Rate limit reached. Waiting for {wait_time:.2f} seconds.")
            await asyncio.sleep(wait_time)
        logger.debug(f"Requests in the last minute: {len(self._request_times)}")
|