Updated more references to AWS session token

This commit is contained in:
Brian Schultheiss 2024-06-23 13:37:38 -07:00
parent 7f91e53548
commit 3fbb25f903
7 changed files with 25 additions and 6 deletions

View file

@ -162,9 +162,10 @@ def completion(
):
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
model_id = optional_params.pop("model_id", None)
@ -175,6 +176,7 @@ def completion(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
)
else:
@ -249,6 +251,7 @@ def completion(
model_id=model_id,
aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
)
return response
@ -281,6 +284,7 @@ def completion(
model_id=model_id,
aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
)
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
@ -414,6 +418,7 @@ async def async_streaming(
aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
aws_session_token: Optional[str] = None,
):
"""
Use aioboto3
@ -429,6 +434,7 @@ async def async_streaming(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
)
else:
@ -481,6 +487,7 @@ async def async_completion(
aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
aws_session_token: Optional[str] = None,
):
"""
Use aioboto3
@ -496,6 +503,7 @@ async def async_completion(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
)
else:
@ -639,9 +647,10 @@ def embedding(
### BOTO3 INIT
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
if aws_access_key_id is not None:
@ -651,6 +660,7 @@ def embedding(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
)
else: