File size: 5,746 Bytes
3b993c4 24d33b9 3b993c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import asyncio
import json
from contextlib import asynccontextmanager
from typing import AsyncIterator
import aiohttp
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from vsp.llm.bedrock.bedrock_model import AnthropicModel, BedrockModel, get_bedrock_model_rate_limit
from vsp.llm.bedrock.bedrock_rate_limiter import BedrockRateLimiter
from vsp.llm.llm_service import LLMService
from vsp.shared import aws_clients, config, logger_factory
# Module-level logger obtained from the project's logger factory, keyed by this module's name.
logger = logger_factory.get_logger(__name__)
class AsyncBedrockService(LLMService):
    """
    An asynchronous service class for making calls to the AWS Bedrock API.

    This class handles authentication, rate limiting, and API interactions
    when using AWS Bedrock's language models.

    The HTTP session only exists while the instance's async context manager
    is active, so calls must be made inside it::

        service = AsyncBedrockService(model)
        async with service() as bedrock:
            text = await bedrock.invoke(user_prompt="...")
    """

    def __init__(self, model: BedrockModel, max_concurrency: int = 10):
        """
        Initialize the AsyncBedrockService.

        Args:
            model (BedrockModel): The Bedrock model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 10.
        """
        bedrock_client = aws_clients.get_bedrock_client()
        # The session is created lazily by ``__call__`` and is None outside of it.
        self._session: aiohttp.ClientSession | None = None
        self._credentials = bedrock_client._get_credentials()
        self._region = config.get_aws_region()
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        # Stay 10 requests/minute under the model's configured limit, but never below 1.
        self._rate_limiter = BedrockRateLimiter(
            rate=max(get_bedrock_model_rate_limit(model).requests_per_minute - 10, 1), per=60.0
        )

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncBedrockService"]:
        """
        Open an HTTP session for the duration of the ``async with`` block.

        Yields:
            AsyncBedrockService: This instance, ready for ``invoke`` calls.
        """
        session = aiohttp.ClientSession()
        self._session = session
        try:
            yield self
        finally:
            # Drop the reference first so later ``invoke`` calls fail with the
            # explicit "not initialized" error instead of using a closed session.
            if self._session is session:
                self._session = None
            # Always close the session this context created, even if the
            # attribute was replaced by another entry in the meantime.
            await session.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the Bedrock API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
                When given, it is prepended to the text returned by the model.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the Bedrock API. This
            implementation either returns a string or raises.

        Raises:
            ValueError: If the model is not an Anthropic model or if no prompts are provided.
            RuntimeError: If called outside the service's ``async with`` context
                (no HTTP session available).
            Exception: For any errors encountered during the API call.
        """
        # Verify that the model is an Anthropic model
        if not isinstance(self._model, AnthropicModel):
            raise ValueError(f"Model {self._model} is not an Anthropic model")

        # Verify that user prompt, system prompt, and partial assistant prompt are not all None
        if not any([user_prompt, system_prompt, partial_assistant_prompt]):
            raise ValueError("At least one of user_prompt, system_prompt, or partial_assistant_prompt must be provided")

        # Fail fast, before a concurrency slot or rate-limit token is consumed.
        session = self._session
        if session is None:
            raise RuntimeError("Session is not initialized")

        async with self._semaphore:  # Use semaphore to limit concurrency
            await self._rate_limiter.acquire()  # rate limit first

            url = f"https://bedrock-runtime.{self._region}.amazonaws.com/model/{self._model.value}/invoke"

            # Always start from a list so a call without a user prompt does not
            # hit an unbound ``messages`` variable.
            messages: list[dict[str, str]] = []
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            body: dict[str, object] = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "messages": messages,
                "temperature": temperature,
            }
            if system_prompt:
                body["system"] = system_prompt

            body_json = json.dumps(body)
            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json",
            }

            # Sign the exact payload that is sent below; the signed headers are
            # then forwarded to aiohttp unchanged.
            request = AWSRequest(method="POST", url=url, data=body_json, headers=headers)
            SigV4Auth(self._credentials, "bedrock", self._region).add_auth(request)

            async with session.post(url, data=body_json, headers=dict(request.headers)) as response:
                logger.info("Bedrock API called", url=response.url)
                if response.status != 200:
                    raise Exception(f"Bedrock API error: {response.status} {await response.text()}")

                response_body = await response.json()
                # A missing or empty content list both mean there is no text to return.
                if "content" not in response_body or not response_body["content"]:
                    raise Exception(f"Content not found in Bedrock response: {response_body}")

                text = str(response_body["content"][0]["text"])
                if partial_assistant_prompt:
                    text = f"{partial_assistant_prompt}{text}"

                # Extract token usage information
                usage = response_body.get("usage", {})
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = usage.get("output_tokens", 0)

                # Log token usage
                logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)

                return text
|