File size: 5,746 Bytes
3b993c4
 
 
 
 
 
 
 
 
24d33b9
 
 
 
3b993c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import asyncio
import json
from contextlib import asynccontextmanager
from typing import AsyncIterator

import aiohttp
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

from vsp.llm.bedrock.bedrock_model import AnthropicModel, BedrockModel, get_bedrock_model_rate_limit
from vsp.llm.bedrock.bedrock_rate_limiter import BedrockRateLimiter
from vsp.llm.llm_service import LLMService
from vsp.shared import aws_clients, config, logger_factory

logger = logger_factory.get_logger(__name__)


class AsyncBedrockService(LLMService):
    """
    An asynchronous service class for making calls to the AWS Bedrock API.

    This class handles authentication (SigV4 request signing), client-side
    rate limiting, concurrency capping, and the HTTP interaction with AWS
    Bedrock's Anthropic language models.

    Usage:
        service = AsyncBedrockService(model)
        async with service() as svc:
            text = await svc.invoke(user_prompt="...")
    """

    def __init__(self, model: BedrockModel, max_concurrency: int = 10):
        """
        Initialize the AsyncBedrockService.

        Args:
            model (BedrockModel): The Bedrock model to use for API calls.
            max_concurrency (int): Maximum number of concurrent API calls. Defaults to 10.
        """
        bedrock_client = aws_clients.get_bedrock_client()
        # Session is created lazily in __call__ so that construction does not
        # require a running event loop.
        self._session: aiohttp.ClientSession | None = None
        self._credentials = bedrock_client._get_credentials()
        self._region = config.get_aws_region()
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self._model = model
        # Leave 10 requests/minute of headroom below the model's published
        # limit; clamp to at least 1 so the limiter never stalls entirely.
        self._rate_limiter = BedrockRateLimiter(
            rate=max(get_bedrock_model_rate_limit(model).requests_per_minute - 10, 1), per=60.0
        )

    @asynccontextmanager
    async def __call__(self) -> AsyncIterator["AsyncBedrockService"]:
        """Open an aiohttp session for the duration of the context and close it on exit."""
        self._session = aiohttp.ClientSession()
        try:
            yield self
        finally:
            if self._session:
                await self._session.close()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        """
        Invoke the Bedrock API with the given prompts and parameters.

        This method handles rate limiting and makes the API call.

        Args:
            user_prompt (str | None): The main prompt from the user.
            system_prompt (str | None): A system message to set the context.
            partial_assistant_prompt (str | None): A partial response from the assistant.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature for response generation.

        Returns:
            str | None: The generated response from the Bedrock API (prefixed with
                partial_assistant_prompt when one was supplied), or None if no response.

        Raises:
            ValueError: If the model is not an Anthropic model or if no prompts are provided.
            RuntimeError: If called outside the session context manager.
            Exception: For any errors encountered during the API call.
        """
        # Verify that the model is an Anthropic model
        if not isinstance(self._model, AnthropicModel):
            raise ValueError(f"Model {self._model} is not an Anthropic model")

        # Verify that user prompt, system prompt, and partial assistant prompt are not all None
        if not any([user_prompt, system_prompt, partial_assistant_prompt]):
            raise ValueError("At least one of user_prompt, system_prompt, or partial_assistant_prompt must be provided")

        async with self._semaphore:  # Use semaphore to limit concurrency
            await self._rate_limiter.acquire()  # rate limit first
            url = f"https://bedrock-runtime.{self._region}.amazonaws.com/model/{self._model.value}/invoke"

            # BUG FIX: previously `messages` was only bound when user_prompt was
            # truthy, so invoking with only partial_assistant_prompt (allowed by
            # the validation above) raised UnboundLocalError. Always start from
            # an empty list and append whichever prompts were provided.
            messages: list[dict[str, str]] = []
            if user_prompt:
                messages.append({"role": "user", "content": user_prompt})
            if partial_assistant_prompt:
                messages.append({"role": "assistant", "content": partial_assistant_prompt})

            body: dict[str, object] = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": max_tokens,
                "messages": messages,
                "temperature": temperature,
            }

            # The system prompt travels in a dedicated top-level field, not in messages.
            if system_prompt:
                body["system"] = system_prompt

            body_json = json.dumps(body)

            headers = {
                "Content-Type": "application/json",
                "Accept": "application/json",
            }

            # Sign the request with SigV4; the signed headers are then replayed
            # through aiohttp (botocore itself does not send the request).
            request = AWSRequest(method="POST", url=url, data=body_json, headers=headers)
            SigV4Auth(self._credentials, "bedrock", self._region).add_auth(request)

            if self._session is None:
                raise RuntimeError("Session is not initialized")

            async with self._session.post(url, data=body_json, headers=dict(request.headers)) as response:
                logger.info("Bedrock API called", url=response.url)
                if response.status != 200:
                    raise Exception(f"Bedrock API error: {response.status} {await response.text()}")

                response_body = await response.json()

                if "content" not in response_body:
                    raise Exception(f"Content not found in Bedrock response: {response_body}")

                text = str(response_body["content"][0]["text"])
                # The API returns only the continuation; re-attach the partial
                # assistant prompt so callers receive the full assistant turn.
                if partial_assistant_prompt:
                    text = f"{partial_assistant_prompt}{text}"

                # Extract token usage information (defaults to 0 if absent)
                input_tokens = response_body.get("usage", {}).get("input_tokens", 0)
                output_tokens = response_body.get("usage", {}).get("output_tokens", 0)

                # Log token usage
                logger.info("Token usage", input_tokens=input_tokens, output_tokens=output_tokens)

                return text