navkast
committed on
Create openrouter client (#7)
Browse files* feat: Implement AsyncOpenRouterService for OpenAI API integration
* fix: Add type annotations and ignore untyped function calls in openrouter.py
* Create openrouter client
src/vsp/llm/openrouter/openrouter.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
from typing import AsyncIterator
|
| 5 |
+
|
| 6 |
+
import openai
|
| 7 |
+
from openai import AsyncOpenAI
|
| 8 |
+
from openai.types.chat import ChatCompletion
|
| 9 |
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
|
| 10 |
+
|
| 11 |
+
from vsp.llm.llm_service import LLMService, RateLimitError
|
| 12 |
+
from vsp.shared import aws_clients, config, logger_factory
|
| 13 |
+
|
| 14 |
+
logger = logger_factory.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class AsyncOpenRouterService(LLMService):
    """
    An asynchronous service class for making calls to the OpenRouter API.

    This class handles authentication (API key fetched from AWS Parameter Store),
    limits the number of concurrent in-flight calls with a semaphore, and applies
    a simple sliding-window rate limit of ``requests_per_minute`` calls per minute.
    """

    def __init__(self, model: str, max_concurrency: int = 3, requests_per_minute: int = 60):
        """
        Initialize the AsyncOpenRouterService.

        Args:
            model (str): The OpenRouter model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 3.
            requests_per_minute (int): Maximum number of requests allowed per minute. Defaults to 60.
        """
        self._client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self._fetch_api_key())
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        self._requests_per_minute = requests_per_minute
        # time.time() timestamps of requests made within the last minute.
        self._request_times: list[float] = []

    @staticmethod
    def _fetch_api_key() -> str:
        """
        Fetch the OpenRouter API key from AWS Parameter Store.

        Returns:
            str: The OpenRouter API key.

        Raises:
            ValueError: If the API key is not found in the Parameter Store.
            RuntimeError: If there's an error accessing the Parameter Store.
        """
        try:
            return aws_clients.fetch_from_parameter_store(config.get_openrouter_api_key_path(), is_secret=True)
        except aws_clients.ParameterNotFoundError as e:
            logger.error("OpenRouter API key not found in Parameter Store", error=str(e))
            raise ValueError("OpenRouter API key not found") from e
        except aws_clients.ParameterStoreAccessError as e:
            logger.error("Error accessing Parameter Store", error=str(e))
            raise RuntimeError("Unable to access OpenRouter API key") from e

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncOpenRouterService"]:
        """Async context manager yielding this service; closes the HTTP client on exit."""
        try:
            yield self
        finally:
            await self._client.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the OpenRouter API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant;
                it is prepended to the returned text.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the OpenRouter API, or None if no response.

        Raises:
            RateLimitError: If the API rate limit is exceeded (after retries).
            openai.APIError: For any other errors encountered during the API call.
        """
        async with self._semaphore:  # Use semaphore to limit concurrency
            await self._wait_for_rate_limit()

            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            response = await self.query_openrouter(max_tokens, temperature, messages)
            logger.info("OpenRouter API called", model=self._model)

            self._update_request_times()

            # BUG FIX: previously `text = str(message.content)` ran before the None
            # check, so a missing content became the literal string "None" and the
            # None branch was unreachable. Check the raw content first.
            content = response.choices[0].message.content
            if content is None:
                logger.warning("No message content from OpenRouter API")
                return None
            text = str(content)

            if partial_assistant_prompt:
                text = f"{partial_assistant_prompt}{text}"

            # Extract token usage information (usage may be absent on some responses)
            usage = response.usage
            input_tokens = usage.prompt_tokens if usage else 0
            output_tokens = usage.completion_tokens if usage else 0

            # Log token usage
            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)

            return text

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(RateLimitError),
    )  # type: ignore
    async def query_openrouter(
        self, max_tokens: int, temperature: float, messages: list[dict[str, str]]
    ) -> ChatCompletion:
        """
        Make a single chat-completion call, retrying with exponential backoff
        when OpenRouter reports a rate limit.

        Args:
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.
            messages (list[dict[str, str]]): Chat messages in OpenAI role/content format.

        Returns:
            ChatCompletion: The raw chat-completion response.

        Raises:
            RateLimitError: When OpenRouter rate-limits the call (triggers the retry).
            openai.APIError: For any other API failure.
        """
        try:
            response = await self._client.chat.completions.create(
                model=self._model,
                messages=messages,  # type: ignore[arg-type]
                max_tokens=max_tokens,
                temperature=temperature,
            )

            return response

        except openai.RateLimitError as e:
            logger.warning("Rate limit error encountered. Retrying...")
            raise RateLimitError("OpenRouter API rate limit exceeded") from e
        except openai.APIError as e:
            logger.error("OpenRouter API error", error=str(e))
            raise

    def _update_request_times(self) -> None:
        """
        Update the list of request times, removing any that are older than one minute.
        """
        current_time = time.time()
        self._request_times = [t for t in self._request_times if current_time - t < 60]
        self._request_times.append(current_time)

    async def _wait_for_rate_limit(self) -> None:
        """
        Wait if necessary to respect the rate limit.

        Sleeps until the oldest tracked request falls outside the 60-second
        window, then prunes stale timestamps and re-checks.
        """
        while len(self._request_times) >= self._requests_per_minute:
            current_time = time.time()
            oldest_request_time = self._request_times[0]
            if current_time - oldest_request_time < 60:
                wait_time = 60 - (current_time - oldest_request_time)
                logger.info(f"Rate limit reached. Waiting for {wait_time:.2f} seconds.")
                await asyncio.sleep(wait_time)
            self._request_times = [t for t in self._request_times if current_time - t < 60]
        logger.debug(f"Requests in the last minute: {len(self._request_times)}")
src/vsp/shared/config.py
CHANGED
|
@@ -93,6 +93,16 @@ def get_openai_api_key_path() -> str:
|
|
| 93 |
return str(config["openai"]["openai_api_key_parameter_store_path"])
|
| 94 |
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
@cache
|
| 97 |
def get_linkedin_key_path() -> str:
|
| 98 |
"""
|
|
|
|
| 93 |
return str(config["openai"]["openai_api_key_parameter_store_path"])
|
| 94 |
|
| 95 |
|
| 96 |
+
@cache
def get_openrouter_api_key_path() -> str:
    """
    Return the Parameter Store path of the OpenRouter API key.

    The path is read from the TOML configuration file; the key value itself
    lives in AWS Parameter Store.
    """
    cfg = _get_config()
    openrouter_section = cfg["openrouter"]
    return str(openrouter_section["openrouter_api_key_parameter_store_path"])
|
| 104 |
+
|
| 105 |
+
|
| 106 |
@cache
|
| 107 |
def get_linkedin_key_path() -> str:
|
| 108 |
"""
|
src/vsp/shared/config.toml
CHANGED
|
@@ -8,5 +8,8 @@ bedrock_aws_account = "339713101814"
|
|
| 8 |
[openai]
|
| 9 |
openai_api_key_parameter_store_path = "/secrets/openai/api_key"
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
[linkedin]
|
| 12 |
linkedin_api_key_parameter_store_path = "/secrets/rapidapi/linkedin"
|
|
|
|
| 8 |
[openai]
|
| 9 |
openai_api_key_parameter_store_path = "/secrets/openai/api_key"
|
| 10 |
|
| 11 |
+
[openrouter]
|
| 12 |
+
openrouter_api_key_parameter_store_path = "/secrets/openrouter/api_key"
|
| 13 |
+
|
| 14 |
[linkedin]
|
| 15 |
linkedin_api_key_parameter_store_path = "/secrets/rapidapi/linkedin"
|
tests/vsp/llm/{__ini__.py → __init__.py}
RENAMED
|
File without changes
|
tests/vsp/llm/openrouter/__init__.py
ADDED
|
File without changes
|
tests/vsp/llm/openrouter/test_integration_openrouter.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from vsp.llm.openrouter.openrouter import AsyncOpenRouterService
|
| 6 |
+
from vsp.shared import logger_factory
|
| 7 |
+
|
| 8 |
+
logger = logger_factory.get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.mark.asyncio
async def test_openrouter_integration():
    """
    Integration test for AsyncOpenRouterService.

    This test makes an actual API call to OpenRouter using the
    nousresearch/hermes-3-llama-3.1-405b:free model. It requires a valid
    OpenRouter API key to be set in the AWS Parameter Store.

    Note: This test should be run sparingly to avoid unnecessary API calls
    and potential costs.
    """
    # Keep the model here in sync with the docstring above; the docstring
    # previously named a different (perplexity) model than the code used.
    model = "nousresearch/hermes-3-llama-3.1-405b:free"
    service = AsyncOpenRouterService(model)

    async with service() as openrouter:
        try:
            response = await openrouter.invoke(
                user_prompt="What is the capital of France?", max_tokens=100, temperature=0.7
            )

            # Log the response
            logger.info("OpenRouter API Response", response=response)

            # Assertions to verify the response
            assert response is not None
            assert isinstance(response, str)
            assert len(response) > 0
            assert "Paris" in response

            logger.info("Integration test passed successfully")
        except Exception as e:
            logger.error("Integration test failed", error=str(e))
            raise


if __name__ == "__main__":
    asyncio.run(test_openrouter_integration())
|