|
|
import os |
|
|
from typing import Any, cast |
|
|
|
|
|
import boto3 |
|
|
import requests |
|
|
from boto3.resources.base import ServiceResource |
|
|
from botocore.client import BaseClient |
|
|
from botocore.exceptions import ClientError |
|
|
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient |
|
|
from mypy_boto3_dynamodb.service_resource import DynamoDBServiceResource |
|
|
from mypy_boto3_s3 import Client as S3Client |
|
|
from mypy_boto3_ssm import Client as SSMClient |
|
|
|
|
|
from vsp.shared import config, logger_factory |
|
|
import os |
|
|
|
|
|
# Module-level logger obtained from the project's logger factory.
logger = logger_factory.get_logger(__name__)

# When True, fetch_from_parameter_store() reads values from environment
# variables (named via path_to_env_var) and returns before contacting
# AWS Systems Manager Parameter Store.
USE_ENV_VAR_INSTEAD = True
|
|
def path_to_env_var(path: str) -> str:
    """
    Converts a parameter path into an environment variable name.

    Leading slashes are stripped, every remaining slash becomes an underscore,
    and the result is upper-cased (e.g. "/app/db/password" -> "APP_DB_PASSWORD").

    Args:
        path (str): The parameter path to convert.

    Returns:
        str: The derived environment variable name.
    """
    return path.lstrip("/").replace("/", "_").upper()
|
|
|
|
|
|
|
|
def _get_session() -> boto3.Session:
    """
    Builds the boto3 session appropriate for where the code is running.

    The presence of the AWS_CONTAINER_CREDENTIALS_RELATIVE_URI environment
    variable is treated as the marker for ECS. In that case a default session
    is created, which is expected to pick up the ECS task's IAM role. In every
    other case the session is bound to the "Geometric-PowerUserAccess" profile
    for local execution.

    Returns:
        boto3.Session: The session for the current execution environment.
    """
    running_in_ecs = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" in os.environ
    # No profile is named in ECS; locally the fixed developer profile is used.
    session_kwargs = {} if running_in_ecs else {"profile_name": "Geometric-PowerUserAccess"}
    return boto3.Session(**session_kwargs)
|
|
|
|
|
|
|
|
class ECSCredentialsError(Exception):
    """Signals that the ECS task credentials could not be retrieved."""
|
|
|
|
|
|
|
|
class RoleAssumptionError(Exception):
    """Signals that an IAM role could not be assumed."""
|
|
|
|
|
|
|
|
def get_credentials() -> dict[str, str]:
    """
    Retrieves AWS credentials based on the execution environment.

    If the AWS_CONTAINER_CREDENTIALS_RELATIVE_URI environment variable is set
    (ECS), credentials are obtained through _get_ecs_credentials(). Otherwise a
    local session is created and the role returned by config.get_role_arn() is
    assumed through _assume_role().

    Returns:
        dict[str, str]: A dictionary containing AccessKeyId, SecretAccessKey, and SessionToken.

    Raises:
        ECSCredentialsError: If there's an error retrieving ECS task credentials.
        RoleAssumptionError: If there's an error assuming the specified IAM role.
    """
    if "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" in os.environ:
        logger.info("Using ECS task's IAM role for credentials")
        return _get_ecs_credentials()

    # Only the local path needs a session (it backs the STS client), so it is
    # created here rather than unconditionally.
    session = _get_session()
    logger.info("Assuming role", role_arn=config.get_role_arn())
    return _assume_role(session)
|
|
|
|
|
|
|
|
def _get_ecs_credentials() -> dict[str, str]:
    """
    Fetches temporary credentials from the ECS container credentials endpoint.

    The endpoint URL is built from the fixed host 169.254.170.2 and the path in
    the AWS_CONTAINER_CREDENTIALS_RELATIVE_URI environment variable.

    Returns:
        dict[str, str]: A dictionary containing AccessKeyId, SecretAccessKey, and
            SessionToken (taken from the response's "Token" field).

    Raises:
        ECSCredentialsError: If the environment variable is unset or empty, the
            HTTP request fails, the endpoint answers with a non-200 status, or
            the response body is not the expected JSON document.
    """
    relative_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
    if not relative_uri:
        # Without this guard the URL would end in the literal string "None".
        raise ECSCredentialsError("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is not set")

    ecs_creds_url = f"http://169.254.170.2{relative_uri}"
    try:
        response = requests.get(ecs_creds_url, timeout=5)
    except requests.RequestException as e:
        # Timeouts and connection failures are reported through the documented error type.
        raise ECSCredentialsError(f"Failed to retrieve ECS task credentials: {e}") from e

    if response.status_code != 200:
        raise ECSCredentialsError(f"Failed to retrieve ECS task credentials. Status code: {response.status_code}")

    try:
        creds = response.json()
        return {
            "AccessKeyId": creds["AccessKeyId"],
            "SecretAccessKey": creds["SecretAccessKey"],
            "SessionToken": creds["Token"],
        }
    except (ValueError, KeyError, TypeError) as e:
        # Invalid JSON, a missing field, or a non-mapping payload.
        raise ECSCredentialsError(f"Malformed ECS task credentials response: {e}") from e
|
|
|
|
|
|
|
|
def _assume_role(session: boto3.Session) -> dict[str, str]:
    """
    Assumes the role given by config.get_role_arn() using an STS client from the session.

    Args:
        session (boto3.Session): The session used to create the STS client.

    Returns:
        dict[str, str]: A dictionary containing AccessKeyId, SecretAccessKey, and SessionToken.

    Raises:
        RoleAssumptionError: If STS reports a ClientError.
    """
    try:
        sts = session.client("sts")
        response = sts.assume_role(RoleArn=config.get_role_arn(), RoleSessionName="AssumeRoleSession")
        temp_creds = response["Credentials"]
        logger.info("Role assumed successfully")
        return {field: temp_creds[field] for field in ("AccessKeyId", "SecretAccessKey", "SessionToken")}
    except ClientError as err:
        raise RoleAssumptionError(f"Failed to assume role: {err}") from err
|
|
|
|
|
|
|
|
def _get_boto3_client(service_name: str) -> BaseClient:
    """
    Builds a boto3 client for the given AWS service.

    The client is created from the session returned by _get_session(), in the
    region from config.get_aws_region(), with SSL enabled and with the
    credentials returned by get_credentials() passed explicitly.

    Args:
        service_name (str): The name of the AWS service for which to create a client.

    Returns:
        BaseClient: A boto3 client for the specified service.
    """
    logger.info("Creating boto3 client", service=service_name)
    aws_session = _get_session()
    creds = get_credentials()

    client_args: dict[str, Any] = {"region_name": config.get_aws_region(), "use_ssl": True}
    if creds:
        client_args["aws_access_key_id"] = creds["AccessKeyId"]
        client_args["aws_secret_access_key"] = creds["SecretAccessKey"]
        client_args["aws_session_token"] = creds["SessionToken"]

    return aws_session.client(service_name, **client_args)
|
|
|
|
|
|
|
|
def _get_boto3_resource(service_name: str) -> ServiceResource:
    """
    Builds a boto3 resource for the given AWS service.

    The resource is created from the session returned by _get_session(), in the
    region from config.get_aws_region(), with SSL enabled and with the
    credentials returned by get_credentials() passed explicitly.

    Args:
        service_name (str): The name of the AWS service for which to create a resource.

    Returns:
        ServiceResource: A boto3 resource for the specified service.
    """
    logger.info("Creating boto3 resource", service=service_name)
    aws_session = _get_session()
    creds = get_credentials()

    resource_args: dict[str, Any] = {"region_name": config.get_aws_region(), "use_ssl": True}
    if creds:
        resource_args["aws_access_key_id"] = creds["AccessKeyId"]
        resource_args["aws_secret_access_key"] = creds["SecretAccessKey"]
        resource_args["aws_session_token"] = creds["SessionToken"]

    return aws_session.resource(service_name, **resource_args)
|
|
|
|
|
|
|
|
def get_ssm_client() -> SSMClient:
    """
    Provides an AWS Systems Manager (SSM) client.

    The client comes from _get_boto3_client("ssm") and is cast to the typed
    SSMClient interface, e.g. for reading parameters from the Parameter Store.

    Returns:
        SSMClient: An SSM client configured for the current execution environment.
    """
    ssm = _get_boto3_client("ssm")
    return cast(SSMClient, ssm)
|
|
|
|
|
|
|
|
def get_s3_client() -> S3Client:
    """
    Provides an AWS S3 client.

    The client comes from _get_boto3_client("s3") and is cast to the typed
    S3Client interface for working with buckets and objects.

    Returns:
        S3Client: An S3 client configured for the current execution environment.
    """
    s3 = _get_boto3_client("s3")
    return cast(S3Client, s3)
|
|
|
|
|
|
|
|
def get_dynamodb_resource() -> DynamoDBServiceResource:
    """
    Provides an AWS DynamoDB resource.

    The resource comes from _get_boto3_resource("dynamodb") and is cast to the
    typed DynamoDBServiceResource interface for working with tables and items.

    Returns:
        DynamoDBServiceResource: A DynamoDB resource configured for the current execution environment.
    """
    dynamodb = _get_boto3_resource("dynamodb")
    return cast(DynamoDBServiceResource, dynamodb)
|
|
|
|
|
|
|
|
def _assume_intermediate_role(role_arn: str, session_name: str) -> dict[str, str]:
    """
    Assumes the given IAM role through an STS client and returns its temporary credentials.

    Args:
        role_arn (str): The ARN of the role to assume.
        session_name (str): An identifier for the assumed role session.

    Returns:
        dict[str, str]: A dictionary containing AccessKeyId, SecretAccessKey, and SessionToken.

    Raises:
        ClientError: If there's an error assuming the role (logged, then re-raised unchanged).
    """
    logger.info("Attempting to assume role", role_arn=role_arn)
    try:
        sts = _get_boto3_client("sts")
        response = sts.assume_role(RoleArn=role_arn, RoleSessionName=session_name)
        temp_creds = response["Credentials"]
        logger.info("Role assumed successfully")
        return {field: temp_creds[field] for field in ("AccessKeyId", "SecretAccessKey", "SessionToken")}
    except ClientError as err:
        logger.error("Error assuming role", error=str(err))
        raise
|
|
|
|
|
|
|
|
def get_bedrock_client() -> BedrockRuntimeClient:
    """
    Provides a Bedrock runtime client that runs under the BedrockAccess role.

    The role ARN is built from config.get_bedrock_account(), assumed via
    _assume_intermediate_role(), and the resulting temporary credentials are
    passed to boto3.client() together with the region from config.get_aws_region().

    Returns:
        BedrockRuntimeClient: A Bedrock client with the assumed role session.
    """
    bedrock_role_arn = f"arn:aws:iam::{config.get_bedrock_account()}:role/BedrockAccess"
    temp_creds = _assume_intermediate_role(bedrock_role_arn, "BedrockAssumeRoleSession")
    bedrock = boto3.client(
        "bedrock-runtime",
        aws_access_key_id=temp_creds["AccessKeyId"],
        aws_secret_access_key=temp_creds["SecretAccessKey"],
        aws_session_token=temp_creds["SessionToken"],
        region_name=config.get_aws_region(),
    )
    return cast(BedrockRuntimeClient, bedrock)
|
|
|
|
|
|
|
|
class ParameterNotFoundError(Exception):
    """Signals that a requested parameter does not exist in the Parameter Store."""
|
|
|
|
|
|
|
|
class ParameterStoreAccessError(Exception):
    """Signals that the Parameter Store could not be accessed."""
|
|
|
|
|
|
|
|
def fetch_from_parameter_store(parameter_name: str, is_secret: bool = False) -> str:
    """
    Fetches the value of a parameter from the environment or from AWS SSM Parameter Store.

    When the module-level USE_ENV_VAR_INSTEAD flag is true, the parameter name is
    converted with path_to_env_var() and looked up in os.environ; the Parameter
    Store is not contacted and 'is_secret' is ignored. Otherwise the parameter is
    read through the SSM client, with 'is_secret' passed as the WithDecryption flag.

    Args:
        parameter_name (str): The name of the parameter to fetch.
        is_secret (bool): Whether to request decryption of the value when reading
            from the Parameter Store. Defaults to False.

    Returns:
        str: The value of the parameter.

    Raises:
        ParameterNotFoundError: If the environment variable or parameter is not found.
        ParameterStoreAccessError: If there's any other error accessing the Parameter Store.
    """
    if USE_ENV_VAR_INSTEAD:
        env_var_name = path_to_env_var(parameter_name)
        logger.info("Fetching parameter from environment variable", parameter=env_var_name)
        value = os.environ.get(env_var_name)
        if value is None:
            raise ParameterNotFoundError(f"Environment variable '{env_var_name}' not found")
        return value

    logger.info("Fetching parameter from Parameter Store", parameter=parameter_name)
    ssm_client = get_ssm_client()
    try:
        response = ssm_client.get_parameter(Name=parameter_name, WithDecryption=is_secret)
    except ssm_client.exceptions.ParameterNotFound as e:
        # Chain explicitly so the original AWS error stays attached as the cause.
        raise ParameterNotFoundError(f"Parameter '{parameter_name}' not found") from e
    except ClientError as e:
        raise ParameterStoreAccessError(f"Error accessing Parameter Store: {str(e)}") from e

    logger.info("Successfully fetched parameter", parameter=parameter_name)
    return str(response["Parameter"].get("Value", ""))
|
|
|