|
|
import asyncio |
|
|
import json |
|
|
from contextlib import asynccontextmanager |
|
|
from typing import AsyncIterator |
|
|
|
|
|
import aiohttp |
|
|
from botocore.auth import SigV4Auth |
|
|
from botocore.awsrequest import AWSRequest |
|
|
|
|
|
from vsp.llm.bedrock.bedrock_model import AnthropicModel, BedrockModel, get_bedrock_model_rate_limit |
|
|
from vsp.llm.bedrock.bedrock_rate_limiter import BedrockRateLimiter |
|
|
from vsp.llm.llm_service import LLMService |
|
|
from vsp.shared import aws_clients, config, logger_factory |
|
|
|
|
|
logger = logger_factory.get_logger(__name__) |
|
|
|
|
|
|
|
|
class AsyncBedrockService(LLMService):
    """
    An asynchronous service class for making calls to the AWS Bedrock API.

    This class handles authentication (SigV4 request signing), client-side
    rate limiting, concurrency capping, and API interactions when using
    AWS Bedrock's language models.
    """

    def __init__(self, model: BedrockModel, max_concurrency: int = 10):
        """
        Initialize the AsyncBedrockService.

        Args:
            model (BedrockModel): The Bedrock model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 10.
        """
        bedrock_client = aws_clients.get_bedrock_client()
        # Session is created lazily in __call__; invoke() raises if it is missing.
        self._session: aiohttp.ClientSession | None = None
        # NOTE(review): _get_credentials() is a private botocore API and may
        # break on a botocore upgrade — confirm no public accessor exists.
        self._credentials = bedrock_client._get_credentials()
        self._region = config.get_aws_region()
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        # Keep 10 requests/minute of headroom below the model's advertised
        # limit to avoid throttling, but never drop below 1 request/minute.
        self._rate_limiter = BedrockRateLimiter(
            rate=max(get_bedrock_model_rate_limit(model).requests_per_minute - 10, 1), per=60.0
        )

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncBedrockService"]:
        """
        Open an aiohttp session for the lifetime of the context.

        Usage: ``async with service() as svc: await svc.invoke(...)``.
        The session is closed on exit even if the body raises.
        """
        self._session = aiohttp.ClientSession()
        try:
            yield self
        finally:
            if self._session:
                await self._session.close()

    @staticmethod
    def _build_request_body(
        user_prompt: str | None,
        system_prompt: str | None,
        partial_assistant_prompt: str | None,
        max_tokens: int,
        temperature: float,
    ) -> dict:
        """
        Build the Anthropic-on-Bedrock request payload.

        ``messages`` is always defined here, even when ``user_prompt`` is
        None — previously the payload construction raised UnboundLocalError
        for calls that supplied only a system or partial assistant prompt,
        even though such calls pass input validation.

        Returns:
            dict: JSON-serializable request body for the invoke endpoint.
        """
        messages: list[dict[str, str]] = []
        if user_prompt:
            messages.append({"role": "user", "content": user_prompt})
        if partial_assistant_prompt:
            # A trailing assistant turn makes the model continue this text.
            messages.append({"role": "assistant", "content": partial_assistant_prompt})

        body: dict = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_tokens,
            "messages": messages,
            "temperature": temperature,
        }
        if system_prompt:
            body["system"] = system_prompt
        return body

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the Bedrock API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the Bedrock API, or None if no response.

        Raises:
            ValueError: If the model is not an Anthropic model or if no prompts are provided.
            RuntimeError: If called outside the ``async with service()`` context.
            Exception: For any errors encountered during the API call.
        """
        if not isinstance(self._model, AnthropicModel):
            raise ValueError(f"Model {self._model} is not an Anthropic model")

        if not any([user_prompt, system_prompt, partial_assistant_prompt]):
            raise ValueError("At least one of user_prompt, system_prompt, or partial_assistant_prompt must be provided")

        async with self._semaphore:
            await self._rate_limiter.acquire()
            url = f"https://bedrock-runtime.{self._region}.amazonaws.com/model/{self._model.value}/invoke"

            body_json = json.dumps(
                self._build_request_body(
                    user_prompt, system_prompt, partial_assistant_prompt, max_tokens, temperature
                )
            )

            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json",
            }

            # Sign the request with SigV4; the signed headers (including
            # Authorization) must be the ones actually sent below.
            request = AWSRequest(method="POST", url=url, data=body_json, headers=headers)
            SigV4Auth(self._credentials, "bedrock", self._region).add_auth(request)

            if self._session is None:
                raise RuntimeError("Session is not initialized")

            async with self._session.post(url, data=body_json, headers=dict(request.headers)) as response:
                logger.info("Bedrock API called", url=response.url)
                if response.status != 200:
                    raise Exception(f"Bedrock API error: {response.status} {await response.text()}")

                response_body = await response.json()

                if "content" not in response_body:
                    raise Exception(f"Content not found in Bedrock response: {response_body}")

                text = str(response_body["content"][0]["text"])
                # The API returns only the continuation; prepend the prefix
                # so callers receive the complete assistant message.
                if partial_assistant_prompt:
                    text = f"{partial_assistant_prompt}{text}"

                usage = response_body.get("usage", {})
                logger.info(
                    "Token usage",
                    input_tokens=usage.get("input_tokens", 0),
                    output_tokens=usage.get("output_tokens", 0),
                )

                return text
|
|
|