# navkast — "Update location of the VSP module (#1)", commit c1f8477 (unverified)
import asyncio
import json
from contextlib import asynccontextmanager
from typing import AsyncIterator
import aiohttp
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from vsp.llm.bedrock.bedrock_model import AnthropicModel, BedrockModel, get_bedrock_model_rate_limit
from vsp.llm.bedrock.bedrock_rate_limiter import BedrockRateLimiter
from vsp.llm.llm_service import LLMService
from vsp.shared import aws_clients, config, logger_factory
logger = logger_factory.get_logger(__name__)
class AsyncBedrockService(LLMService):
    """
    An asynchronous service class for making calls to the AWS Bedrock API.

    This class handles authentication (SigV4 request signing), client-side rate
    limiting, concurrency capping, and the HTTP interaction with AWS Bedrock's
    Anthropic language models. Use it as an async context manager::

        service = AsyncBedrockService(model)
        async with service() as svc:
            text = await svc.invoke(user_prompt="...")
    """

    def __init__(self, model: BedrockModel, max_concurrency: int = 10):
        """
        Initialize the AsyncBedrockService.

        Args:
            model (BedrockModel): The Bedrock model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 10.
        """
        bedrock_client = aws_clients.get_bedrock_client()
        # Created lazily by __call__; invoke() refuses to run without it.
        self._session: aiohttp.ClientSession | None = None
        # Reuse botocore's resolved credentials so SigV4 signing matches the
        # identity the rest of the app uses.
        self._credentials = bedrock_client._get_credentials()
        self._region = config.get_aws_region()
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        # Stay 10 requests/minute under the model's published limit to leave
        # headroom; floor at 1 so the limiter never receives a non-positive rate.
        self._rate_limiter = BedrockRateLimiter(
            rate=max(get_bedrock_model_rate_limit(model).requests_per_minute - 10, 1), per=60.0
        )

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncBedrockService"]:
        """Open an aiohttp session for the lifetime of the context and close it on exit."""
        self._session = aiohttp.ClientSession()
        try:
            yield self
        finally:
            if self._session:
                await self._session.close()
            # Reset so a later invoke() outside the context fails fast with a
            # clear RuntimeError instead of posting on a closed session.
            self._session = None

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the Bedrock API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the Bedrock API, or None if no response.
                When partial_assistant_prompt is given, it is prepended to the model output.

        Raises:
            ValueError: If the model is not an Anthropic model or if no prompts are provided.
            RuntimeError: If called outside the async context manager (no open session).
            Exception: For any errors encountered during the API call.
        """
        # Only Anthropic models use the "messages" request schema built below.
        if not isinstance(self._model, AnthropicModel):
            raise ValueError(f"Model {self._model} is not an Anthropic model")
        # Verify that user prompt, system prompt, and partial assistant prompt are not all None
        if not any([user_prompt, system_prompt, partial_assistant_prompt]):
            raise ValueError("At least one of user_prompt, system_prompt, or partial_assistant_prompt must be provided")
        async with self._semaphore:  # Use semaphore to limit concurrency
            await self._rate_limiter.acquire()  # rate limit first
            url = f"https://bedrock-runtime.{self._region}.amazonaws.com/model/{self._model.value}/invoke"
            # Build messages unconditionally: the original only bound `messages`
            # when user_prompt was set, so a call with only a partial assistant
            # prompt (allowed by the validation above) raised NameError.
            messages: list[dict[str, str]] = []
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})
            body: dict[str, object] = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "messages": messages,
                "temperature": temperature,
            }
            if system_prompt:
                # The Anthropic schema carries the system prompt as a top-level
                # field, not as a message.
                body["system"] = system_prompt
            body_json = json.dumps(body)
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json",
            }
            # Sign the request with SigV4, then replay the signed headers on the
            # aiohttp POST (aiohttp does not integrate with botocore signing).
            request = AWSRequest(method="POST", url=url, data=body_json, headers=headers)
            SigV4Auth(self._credentials, "bedrock", self._region).add_auth(request)
            if self._session is None:
                raise RuntimeError("Session is not initialized")
            async with self._session.post(url, data=body_json, headers=dict(request.headers)) as response:
                logger.info("Bedrock API called", url=response.url)
                if response.status != 200:
                    raise Exception(f"Bedrock API error: {response.status} {await response.text()}")
                response_body = await response.json()
                if "content" not in response_body:
                    raise Exception(f"Content not found in Bedrock response: {response_body}")
                text = str(response_body["content"][0]["text"])
                if partial_assistant_prompt:
                    # Return the full assistant message: the supplied prefix plus
                    # the model's continuation.
                    text = f"{partial_assistant_prompt}{text}"
                # Extract token usage information
                input_tokens = response_body.get("usage", {}).get("input_tokens", 0)
                output_tokens = response_body.get("usage", {}).get("output_tokens", 0)
                # Log token usage
                logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)
                return text