import asyncio
import json
from contextlib import asynccontextmanager
from typing import AsyncIterator

import aiohttp
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

from vsp.llm.bedrock.bedrock_model import AnthropicModel, BedrockModel, get_bedrock_model_rate_limit
from vsp.llm.bedrock.bedrock_rate_limiter import BedrockRateLimiter
from vsp.llm.llm_service import LLMService
from vsp.shared import aws_clients, config, logger_factory

logger = logger_factory.get_logger(__name__)


class AsyncBedrockService(LLMService):
    """
    An asynchronous service class for making calls to the AWS Bedrock API.

    This class handles authentication, rate limiting, and API interactions
    when using AWS Bedrock's language models.
    """

    def __init__(self, model: BedrockModel, max_concurrency: int = 10):
        """
        Initialize the AsyncBedrockService.

        Args:
            model (BedrockModel): The Bedrock model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 10.
        """
        bedrock_client = aws_clients.get_bedrock_client()
        self._session: aiohttp.ClientSession | None = None
        self._credentials = bedrock_client._get_credentials()
        self._region = config.get_aws_region()
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        self._rate_limiter = BedrockRateLimiter(
            rate=max(get_bedrock_model_rate_limit(model).requests_per_minute - 10, 1), per=60.0
        )

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncBedrockService"]:
        self._session = aiohttp.ClientSession()
        try:
            yield self
        finally:
            if self._session:
                await self._session.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the Bedrock API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the Bedrock API, or None if no response.

        Raises:
            ValueError: If the model is not an Anthropic model or if no prompts are provided.
            Exception: For any errors encountered during the API call.
""" # Verify that the model is an Anthropic model if not isinstance(self._model, AnthropicModel): raise ValueError(f"Model {self._model} is not an Anthropic model") # Verify that user prompt, system prompt, and partial assistant prompt are not all None if not any([user_prompt, system_prompt, partial_assistant_prompt]): raise ValueError("At least one of user_prompt, system_prompt, or partial_assistant_prompt must be provided") async with self._semaphore: # Use semaphore to limit concurrency await self._rate_limiter.acquire() # rate limit first url = f"https://bedrock-runtime.{self._region}.amazonaws.com/model/{self._model.value}/invoke" if user_prompt: messages = [{"role": "user", "content": user_prompt}] if partial_assistant_prompt: messages.append({"role": "assistant", "content": partial_assistant_prompt}) body = { "anthropic_version": "bedrock-2023-05-31", "max_tokens": max_tokens, "messages": messages, "temperature": temperature, } if system_prompt: body["system"] = system_prompt body_json = json.dumps(body) headers = { "Content-Type": "application/json", "Accept": "application/json", } request = AWSRequest(method="POST", url=url, data=body_json, headers=headers) SigV4Auth(self._credentials, "bedrock", self._region).add_auth(request) if self._session is None: raise RuntimeError("Session is not initialized") async with self._session.post(url, data=body_json, headers=dict(request.headers)) as response: logger.info("Bedrock API called", url=response.url) if response.status != 200: raise Exception(f"Bedrock API error: {response.status} {await response.text()}") response_body = await response.json() if "content" not in response_body: raise Exception(f"Content not found in Bedrock response: {response_body}") text = str(response_body["content"][0]["text"]) if partial_assistant_prompt: text = f"{partial_assistant_prompt}{text}" # Extract token usage information input_tokens = response_body.get("usage", {}).get("input_tokens", 0) output_tokens = response_body.get("usage", {}).get("output_tokens", 0) # Log token usage logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens) return text