# vsp-demo/src/vsp/llm/bedrock/bedrock_model.py
# navkast
# Update location of the VSP module (#1)
# c1f8477 unverified
from dataclasses import dataclass
from enum import Enum
@dataclass
class RateLimit:
"""
Dataclass representing rate limit information for Bedrock models.
Attributes:
requests_per_minute (int): The maximum number of requests allowed per minute.
tokens_per_minute (int | None): The maximum number of tokens allowed per minute, if applicable.
regions (str): The AWS regions where this rate limit applies.
"""
requests_per_minute: int
tokens_per_minute: int | None
regions: str
class BedrockModel(Enum):
    """
    Base enum class for Bedrock models.

    This class defines the interface for Bedrock model enums and provides
    a method to get rate limits for specific models. It declares no members
    itself; provider enums subclass it, add their models as members, and
    override ``_rate_limits``.

    Methods:
        get_rate_limit(model: BedrockModel) -> RateLimit:
            Get the rate limit for a specific model.
    """

    @classmethod
    def get_rate_limit(cls, model: "BedrockModel") -> "RateLimit":
        """
        Get the rate limit for a specific Bedrock model.

        The table is taken from the enum class that ``model`` belongs to
        rather than from ``cls``, so the call works no matter which class it
        is made on (``BedrockModel.get_rate_limit(m)`` included).

        Args:
            model (BedrockModel): The Bedrock model to get the rate limit for.

        Returns:
            RateLimit: The rate limit information for the specified model.

        Raises:
            NotImplementedError: If the class supplying the table does not
                override ``_rate_limits``.
            KeyError: If the table has no entry for ``model``.
        """
        # Fall back to ``cls`` for values that are not BedrockModel members so
        # such lookups keep failing with the same exceptions as before.
        owner = type(model) if isinstance(model, BedrockModel) else cls
        return owner._rate_limits()[model]

    @classmethod
    def _rate_limits(cls) -> dict["BedrockModel", "RateLimit"]:
        """
        Define the rate limits for each Bedrock model.

        This method must be implemented by subclasses.

        Returns:
            dict[BedrockModel, RateLimit]: A dictionary mapping each model to its rate limit.

        Raises:
            NotImplementedError: If not implemented by a subclass.
        """
        raise NotImplementedError("Subclasses must implement this method")
class AnthropicModel(BedrockModel):
    """
    Anthropic models offered through Bedrock.

    Members cover several Claude versions; each member's value is the
    model identifier string.
    """

    CLAUDE_3_OPUS = "anthropic.claude-3-opus-20240229-v1:0"
    CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"
    CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
    CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"
    CLAUDE_INSTANT_1_2 = "anthropic.claude-instant-v1"

    @classmethod
    def _rate_limits(cls) -> dict[BedrockModel, RateLimit]:
        """Return the table mapping every Claude member to its rate limit."""
        # Sonnet 3, Haiku 3 and Instant share one region description.
        us_east_west = "US East (N. Virginia) (us-east-1), US West (Oregon) (us-west-2)"
        everywhere = "All"

        limits: dict[BedrockModel, RateLimit] = {}
        limits[cls.CLAUDE_3_OPUS] = RateLimit(
            requests_per_minute=50, tokens_per_minute=400_000, regions=everywhere
        )
        limits[cls.CLAUDE_3_5_SONNET] = RateLimit(
            requests_per_minute=50, tokens_per_minute=400_000, regions=everywhere
        )
        limits[cls.CLAUDE_3_SONNET] = RateLimit(
            requests_per_minute=500, tokens_per_minute=1_000_000, regions=us_east_west
        )
        limits[cls.CLAUDE_3_HAIKU] = RateLimit(
            requests_per_minute=1000, tokens_per_minute=2_000_000, regions=us_east_west
        )
        limits[cls.CLAUDE_INSTANT_1_2] = RateLimit(
            requests_per_minute=1000, tokens_per_minute=1_000_000, regions=us_east_west
        )
        return limits
class MetaModel(BedrockModel):
    """
    Meta models offered through Bedrock.

    Members cover several Llama versions; each member's value is the
    model identifier string.
    """

    LLAMA_2_70B_CHAT = "meta.llama2-70b-chat-v1"
    LLAMA_2_13B_CHAT = "meta.llama2-13b-chat-v1"
    LLAMA_3_8B_INSTRUCT = "meta.llama3-8b-instruct-v1:0"
    LLAMA_3_70B_INSTRUCT = "meta.llama3-70b-instruct-v1:0"

    @classmethod
    def _rate_limits(cls) -> dict[BedrockModel, RateLimit]:
        """Return the table mapping every Llama member to its rate limit."""
        # Only the requests-per-minute figure differs between these models;
        # the token limit and region text are the same for all of them.
        requests_by_model = (
            (cls.LLAMA_2_70B_CHAT, 400),
            (cls.LLAMA_2_13B_CHAT, 800),
            (cls.LLAMA_3_8B_INSTRUCT, 800),
            (cls.LLAMA_3_70B_INSTRUCT, 400),
        )
        return {
            member: RateLimit(rpm, 300_000, "All")
            for member, rpm in requests_by_model
        }
class MistralModel(BedrockModel):
    """
    Mistral models offered through Bedrock.

    Members cover several Mistral and Mixtral versions; each member's value
    is the model identifier string.
    """

    MISTRAL_7B_INSTRUCT = "mistral.mistral-7b-instruct-v0:2"
    MIXTRAL_8X7B_INSTRUCT = "mistral.mixtral-8x7b-instruct-v0:1"
    MISTRAL_LARGE = "mistral.mistral-large-2402-v1:0"
    MISTRAL_SMALL = "mistral.mistral-small-2402-v1:0"

    @classmethod
    def _rate_limits(cls) -> dict[BedrockModel, RateLimit]:
        """Return the table mapping every Mistral member to its rate limit."""
        # Only the requests-per-minute figure differs between these models;
        # the token limit and region text are the same for all of them.
        requests_by_model = (
            (cls.MISTRAL_7B_INSTRUCT, 800),
            (cls.MIXTRAL_8X7B_INSTRUCT, 400),
            (cls.MISTRAL_LARGE, 400),
            (cls.MISTRAL_SMALL, 400),
        )
        return {
            member: RateLimit(rpm, 300_000, "All")
            for member, rpm in requests_by_model
        }
# Function to get rate limit for any Bedrock model
def get_bedrock_model_rate_limit(model: BedrockModel) -> RateLimit:
    """
    Look up the rate limit of a Bedrock model from any provider enum.

    Convenience wrapper: it resolves the enum class the model belongs to and
    delegates to that class's ``get_rate_limit``.

    Args:
        model (BedrockModel): The Bedrock model to get the rate limit for.

    Returns:
        RateLimit: The rate limit information for the specified model.
    """
    provider_enum = type(model)
    return provider_enum.get_rate_limit(model)