navkast committed
Commit a1ead4c · unverified · 1 Parent(s): 6abd06b

Create openrouter client (#7)

* feat: Implement AsyncOpenRouterService for OpenAI API integration

* fix: Add type annotations and ignore untyped function calls in openrouter.py

* Create openrouter client

src/vsp/llm/openrouter/openrouter.py ADDED
@@ -0,0 +1,172 @@
+import asyncio
+import time
+from contextlib import asynccontextmanager
+from typing import AsyncIterator
+
+import openai
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletion
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
+
+from vsp.llm.llm_service import LLMService, RateLimitError
+from vsp.shared import aws_clients, config, logger_factory
+
+logger = logger_factory.get_logger(__name__)
+
+
+class AsyncOpenRouterService(LLMService):
+    """
+    An asynchronous service class for making calls to the OpenRouter API.
+
+    This class handles authentication and implements a basic rate limiting strategy.
+    """
+
+    def __init__(self, model: str, max_concurrency: int = 3, requests_per_minute: int = 60):
+        """
+        Initialize the AsyncOpenRouterService.
+
+        Args:
+            model (str): The OpenRouter model to use for API calls.
+            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 3.
+            requests_per_minute (int): Maximum number of requests allowed per minute. Defaults to 60.
+        """
+        self._client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self._fetch_api_key())
+        self._semaphore = asyncio.Semaphore(max_concurrency)
+        self._model = model
+        self._requests_per_minute = requests_per_minute
+        self._request_times: list[float] = []
+
+    @staticmethod
+    def _fetch_api_key() -> str:
+        """
+        Fetch the OpenRouter API key from AWS Parameter Store.
+
+        Returns:
+            str: The OpenRouter API key.
+
+        Raises:
+            ValueError: If the API key is not found in the Parameter Store.
+            RuntimeError: If there's an error accessing the Parameter Store.
+        """
+        try:
+            return aws_clients.fetch_from_parameter_store(config.get_openrouter_api_key_path(), is_secret=True)
+        except aws_clients.ParameterNotFoundError as e:
+            logger.error("OpenRouter API key not found in Parameter Store", error=str(e))
+            raise ValueError("OpenRouter API key not found") from e
+        except aws_clients.ParameterStoreAccessError as e:
+            logger.error("Error accessing Parameter Store", error=str(e))
+            raise RuntimeError("Unable to access OpenRouter API key") from e
+
+    @asynccontextmanager
+    async def __call__(self) -> AsyncIterator["AsyncOpenRouterService"]:
+        try:
+            yield self
+        finally:
+            await self._client.close()
+
+    async def invoke(
+        self,
+        user_prompt: str | None = None,
+        system_prompt: str | None = None,
+        partial_assistant_prompt: str | None = None,
+        max_tokens: int = 1000,
+        temperature: float = 0.0,
+    ) -> str | None:
+        """
+        Invoke the OpenRouter API with the given prompts and parameters.
+
+        This method handles rate limiting and makes the API call.
+
+        Args:
+            user_prompt (str | None): The main prompt from the user.
+            system_prompt (str | None): A system message to set the context.
+            partial_assistant_prompt (str | None): A partial response from the assistant.
+            max_tokens (int): Maximum number of tokens in the response.
+            temperature (float): Sampling temperature for response generation.
+
+        Returns:
+            str | None: The generated response from the OpenRouter API, or None if no response.
+
+        Raises:
+            RateLimitError: If the API rate limit is exceeded.
+            openai.APIError: For any other errors encountered during the API call.
+        """
+        async with self._semaphore:  # Use semaphore to limit concurrency
+            await self._wait_for_rate_limit()
+
+            messages = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            if user_prompt:
+                messages.append({"role": "user", "content": user_prompt})
+            if partial_assistant_prompt:
+                messages.append({"role": "assistant", "content": partial_assistant_prompt})
+
+            response = await self.query_openrouter(max_tokens, temperature, messages)
+            logger.info("OpenRouter API called", model=self._model)
+
+            self._update_request_times()
+
+            message = response.choices[0].message
+            if message.content is None:
+                logger.warning("No message content from OpenRouter API")
+                return None
+            text = message.content
+
+            if partial_assistant_prompt:
+                text = f"{partial_assistant_prompt}{text}"
+
+            # Extract token usage information
+            usage = response.usage
+            input_tokens = usage.prompt_tokens if usage else 0
+            output_tokens = usage.completion_tokens if usage else 0
+
+            # Log token usage
+            logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)
+
+            return text
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        retry=retry_if_exception_type(RateLimitError),
+    )  # type: ignore
+    async def query_openrouter(self, max_tokens, temperature, messages) -> ChatCompletion:
+        try:
+            response = await self._client.chat.completions.create(
+                model=self._model,
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+            )
+
+            return response
+
+        except openai.RateLimitError as e:
+            logger.warning("Rate limit error encountered. Retrying...")
+            raise RateLimitError("OpenRouter API rate limit exceeded") from e
+        except openai.APIError as e:
+            logger.error("OpenRouter API error", error=str(e))
+            raise
+
+    def _update_request_times(self) -> None:
+        """
+        Update the list of request times, removing any that are older than one minute.
+        """
+        current_time = time.time()
+        self._request_times = [t for t in self._request_times if current_time - t < 60]
+        self._request_times.append(current_time)
+
+    async def _wait_for_rate_limit(self) -> None:
+        """
+        Wait if necessary to respect the rate limit.
+        """
+        while len(self._request_times) >= self._requests_per_minute:
+            current_time = time.time()
+            oldest_request_time = self._request_times[0]
+            if current_time - oldest_request_time < 60:
+                wait_time = 60 - (current_time - oldest_request_time)
+                logger.info(f"Rate limit reached. Waiting for {wait_time:.2f} seconds.")
+                await asyncio.sleep(wait_time)
+            self._request_times = [t for t in self._request_times if current_time - t < 60]
+        logger.debug(f"Requests in the last minute: {len(self._request_times)}")
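For orientation, here is a minimal usage sketch of the new client (not part of the commit; the model name and prompt are placeholders), following the same async context-manager pattern the integration test below uses:

    import asyncio

    from vsp.llm.openrouter.openrouter import AsyncOpenRouterService

    async def main() -> None:
        # Placeholder model id; any OpenRouter model identifier can be passed.
        service = AsyncOpenRouterService("nousresearch/hermes-3-llama-3.1-405b:free")
        async with service() as openrouter:
            # invoke() applies the concurrency semaphore and per-minute rate limiting internally.
            answer = await openrouter.invoke(user_prompt="Say hello in one short sentence.", max_tokens=50)
            print(answer)

    if __name__ == "__main__":
        asyncio.run(main())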
src/vsp/shared/config.py CHANGED
@@ -93,6 +93,16 @@ def get_openai_api_key_path() -> str:
     return str(config["openai"]["openai_api_key_parameter_store_path"])


+@cache
+def get_openrouter_api_key_path() -> str:
+    """
+    Reads the OpenRouter API key path from the TOML configuration file.
+    Key is in AWS parameter store
+    """
+    config = _get_config()
+    return str(config["openrouter"]["openrouter_api_key_parameter_store_path"])
+
+
 @cache
 def get_linkedin_key_path() -> str:
     """
src/vsp/shared/config.toml CHANGED
@@ -8,5 +8,8 @@ bedrock_aws_account = "339713101814"
 [openai]
 openai_api_key_parameter_store_path = "/secrets/openai/api_key"

+[openrouter]
+openrouter_api_key_parameter_store_path = "/secrets/openrouter/api_key"
+
 [linkedin]
 linkedin_api_key_parameter_store_path = "/secrets/rapidapi/linkedin"
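The client reads the key from the parameter store path configured here. A one-off seeding sketch, under the assumption that a standard SecureString parameter is what is_secret=True in _fetch_api_key expects (this uses boto3 directly rather than the project's aws_clients wrapper, and the helper name is hypothetical):

    import boto3

    def put_openrouter_api_key(api_key: str) -> None:
        # Hypothetical helper: writes the key to the path declared in config.toml above.
        ssm = boto3.client("ssm")
        ssm.put_parameter(
            Name="/secrets/openrouter/api_key",
            Value=api_key,
            Type="SecureString",
            Overwrite=True,
        )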
tests/vsp/llm/{__ini__.py → __init__.py} RENAMED
File without changes
tests/vsp/llm/openrouter/__init__.py ADDED
File without changes
tests/vsp/llm/openrouter/test_integration_openrouter.py ADDED
@@ -0,0 +1,48 @@
+import asyncio
+
+import pytest
+
+from vsp.llm.openrouter.openrouter import AsyncOpenRouterService
+from vsp.shared import logger_factory
+
+logger = logger_factory.get_logger(__name__)
+
+
+@pytest.mark.asyncio
+async def test_openrouter_integration():
+    """
+    Integration test for AsyncOpenRouterService.
+
+    This test makes an actual API call to OpenRouter using the
+    nousresearch/hermes-3-llama-3.1-405b:free model. It requires a valid
+    OpenRouter API key to be set in the AWS Parameter Store.
+
+    Note: This test should be run sparingly to avoid unnecessary API calls
+    and potential costs.
+    """
+    model = "nousresearch/hermes-3-llama-3.1-405b:free"
+    service = AsyncOpenRouterService(model)
+
+    async with service() as openrouter:
+        try:
+            response = await openrouter.invoke(
+                user_prompt="What is the capital of France?", max_tokens=100, temperature=0.7
+            )
+
+            # Log the response
+            logger.info("OpenRouter API Response", response=response)
+
+            # Assertions to verify the response
+            assert response is not None
+            assert isinstance(response, str)
+            assert len(response) > 0
+            assert "Paris" in response
+
+            logger.info("Integration test passed successfully")
+        except Exception as e:
+            logger.error("Integration test failed", error=str(e))
+            raise
+
+
+if __name__ == "__main__":
+    asyncio.run(test_openrouter_integration())