Updated more references to AWS session token

This commit is contained in:
Brian Schultheiss 2024-06-23 13:37:38 -07:00
parent 7f91e53548
commit 3fbb25f903
7 changed files with 25 additions and 6 deletions

View file

@ -59,6 +59,7 @@ response = completion(
messages=[{ "content": "Hello, how are you?","role": "user"}], messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_access_key_id="", aws_access_key_id="",
aws_secret_access_key="", aws_secret_access_key="",
aws_session_token="",
aws_region_name="", aws_region_name="",
) )
``` ```

View file

@ -538,6 +538,7 @@ response = completion(
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name="my-test-session", aws_session_name="my-test-session",
) )

View file

@ -576,6 +576,7 @@ def init_bedrock_client(
params_to_check = [ params_to_check = [
aws_access_key_id, aws_access_key_id,
aws_secret_access_key, aws_secret_access_key,
aws_session_token,
aws_region_name, aws_region_name,
aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint,
aws_session_name, aws_session_name,
@ -1344,7 +1345,7 @@ def embedding(
encoding=None, encoding=None,
): ):
### BOTO3 INIT ### ### BOTO3 INIT ###
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None) aws_session_token = optional_params.pop("aws_session_token", None)
@ -1437,7 +1438,7 @@ def image_generation(
Bedrock Image Gen endpoint support Bedrock Image Gen endpoint support
""" """
### BOTO3 INIT ### ### BOTO3 INIT ###
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None) aws_session_token = optional_params.pop("aws_session_token", None)

View file

@ -745,7 +745,7 @@ class BedrockLLM(BaseLLM):
provider = model.split(".")[0] provider = model.split(".")[0]
## CREDENTIALS ## ## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None) aws_session_token = optional_params.pop("aws_session_token", None)

View file

@ -162,9 +162,10 @@ def completion(
): ):
import boto3 import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None) aws_region_name = optional_params.pop("aws_region_name", None)
model_id = optional_params.pop("model_id", None) model_id = optional_params.pop("model_id", None)
@ -175,6 +176,7 @@ def completion(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name, region_name=aws_region_name,
) )
else: else:
@ -249,6 +251,7 @@ def completion(
model_id=model_id, model_id=model_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
) )
return response return response
@ -281,6 +284,7 @@ def completion(
model_id=model_id, model_id=model_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
) )
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode( data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
@ -414,6 +418,7 @@ async def async_streaming(
aws_secret_access_key: Optional[str], aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str], aws_access_key_id: Optional[str],
aws_region_name: Optional[str], aws_region_name: Optional[str],
aws_session_token: Optional[str] = None,
): ):
""" """
Use aioboto3 Use aioboto3
@ -429,6 +434,7 @@ async def async_streaming(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name, region_name=aws_region_name,
) )
else: else:
@ -481,6 +487,7 @@ async def async_completion(
aws_secret_access_key: Optional[str], aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str], aws_access_key_id: Optional[str],
aws_region_name: Optional[str], aws_region_name: Optional[str],
aws_session_token: Optional[str] = None,
): ):
""" """
Use aioboto3 Use aioboto3
@ -496,6 +503,7 @@ async def async_completion(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name, region_name=aws_region_name,
) )
else: else:
@ -639,9 +647,10 @@ def embedding(
### BOTO3 INIT ### BOTO3 INIT
import boto3 import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None) aws_region_name = optional_params.pop("aws_region_name", None)
if aws_access_key_id is not None: if aws_access_key_id is not None:
@ -651,6 +660,7 @@ def embedding(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name, region_name=aws_region_name,
) )
else: else:

View file

@ -6343,6 +6343,7 @@ async def model_info_v2(
_model["litellm_params"].pop("vertex_credentials", None) _model["litellm_params"].pop("vertex_credentials", None)
_model["litellm_params"].pop("aws_access_key_id", None) _model["litellm_params"].pop("aws_access_key_id", None)
_model["litellm_params"].pop("aws_secret_access_key", None) _model["litellm_params"].pop("aws_secret_access_key", None)
_model["litellm_params"].pop("aws_session_token", None)
verbose_proxy_logger.debug("all_models: %s", all_models) verbose_proxy_logger.debug("all_models: %s", all_models)
return {"data": all_models} return {"data": all_models}
@ -6859,6 +6860,7 @@ async def model_info_v1(
model["litellm_params"].pop("vertex_credentials", None) model["litellm_params"].pop("vertex_credentials", None)
model["litellm_params"].pop("aws_access_key_id", None) model["litellm_params"].pop("aws_access_key_id", None)
model["litellm_params"].pop("aws_secret_access_key", None) model["litellm_params"].pop("aws_secret_access_key", None)
model["litellm_params"].pop("aws_session_token", None)
verbose_proxy_logger.debug("all_models: %s", all_models) verbose_proxy_logger.debug("all_models: %s", all_models)
return {"data": all_models} return {"data": all_models}

View file

@ -145,6 +145,7 @@ class GenericLiteLLMParams(BaseModel):
## AWS BEDROCK / SAGEMAKER ## ## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str] = None aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None
aws_region_name: Optional[str] = None aws_region_name: Optional[str] = None
## IBM WATSONX ## ## IBM WATSONX ##
watsonx_region_name: Optional[str] = None watsonx_region_name: Optional[str] = None
@ -178,6 +179,7 @@ class GenericLiteLLMParams(BaseModel):
## AWS BEDROCK / SAGEMAKER ## ## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None, aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None, aws_region_name: Optional[str] = None,
## IBM WATSONX ## ## IBM WATSONX ##
watsonx_region_name: Optional[str] = None, watsonx_region_name: Optional[str] = None,
@ -242,6 +244,7 @@ class LiteLLM_Params(GenericLiteLLMParams):
## AWS BEDROCK / SAGEMAKER ## ## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None, aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None, aws_region_name: Optional[str] = None,
**params, **params,
): ):
@ -307,6 +310,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
## AWS BEDROCK / SAGEMAKER ## ## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str] aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str] aws_secret_access_key: Optional[str]
aws_session_token: Optional[str]
aws_region_name: Optional[str] aws_region_name: Optional[str]
## IBM WATSONX ## ## IBM WATSONX ##
watsonx_region_name: Optional[str] watsonx_region_name: Optional[str]