| import os | |
| from traceback import print_exc | |
| import boto3 | |
| from handler import ContentHandler | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file (if one exists) into the process
# environment before any of them are read below.
load_dotenv()

# Connection settings; each is None when the variable is not set.
endpoint_name = os.getenv("AWS_ENDPOINT_NAME")
aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
aws_region_name = os.getenv("AWS_REGION_NAME")

# Shared SageMaker runtime client, created once at import time and reused by
# every call to invoke_endpoint().
boto_client = boto3.client(
    service_name='sagemaker-runtime',
    region_name=aws_region_name,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)

# Converts prompts to request bodies and response bodies back to results.
content_handler = ContentHandler()
def invoke_endpoint(
    input_,
    model_parameters,
):
    """Send a prompt to the configured SageMaker endpoint and return its output.

    The prompt and model parameters are serialized by
    ``content_handler.transform_input`` and sent as an ``application/json``
    request to the endpoint named by ``endpoint_name``. The ``Body`` of the
    response is then passed through ``content_handler.transform_output``.

    Args:
        input_: Prompt forwarded to ``transform_input`` as ``prompt``.
        model_parameters: Forwarded to ``transform_input`` as ``model_kwargs``.

    Returns:
        Whatever ``content_handler.transform_output`` returns, or ``None`` if
        serialization, the request, or decoding raised an exception (the
        traceback is printed in that case).
    """
    try:
        payload = content_handler.transform_input(
            prompt=input_,
            model_kwargs=model_parameters,
        )
        response = boto_client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType='application/json',
            Body=payload,
        )
        return content_handler.transform_output(response['Body'])
    except Exception:
        # Best-effort call: report the failure and signal it with None.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being swallowed here.
        print_exc()
        return None