diff --git a/docs/my-website/docs/providers/aws_sagemaker.md b/docs/my-website/docs/providers/aws_sagemaker.md index 2b65709e8e..5793fb05ae 100644 --- a/docs/my-website/docs/providers/aws_sagemaker.md +++ b/docs/my-website/docs/providers/aws_sagemaker.md @@ -59,6 +59,7 @@ response = completion( messages=[{ "content": "Hello, how are you?","role": "user"}], aws_access_key_id="", aws_secret_access_key="", + aws_session_token="", aws_region_name="", ) ``` diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index adbc54caa7..7f9b21b96b 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -538,6 +538,7 @@ response = completion( aws_region_name=aws_region_name, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, aws_role_name=aws_role_name, aws_session_name="my-test-session", ) @@ -553,7 +554,7 @@ This is a deprecated flow. Boto3 is not async. And boto3.client does not let us Experimental - 2024-Jun-23: aws_access_key_id, aws_secret_access_key=, and aws_session_token will be extracted from boto3.client and be passed onto the httpx client - + ::: Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth. diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 6c941bb558..2403edf814 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -576,6 +576,7 @@ def init_bedrock_client( params_to_check = [ aws_access_key_id, aws_secret_access_key, + aws_session_token, aws_region_name, aws_bedrock_runtime_endpoint, aws_session_name, @@ -1344,7 +1345,7 @@ def embedding( encoding=None, ): ### BOTO3 INIT ### - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_session_token = optional_params.pop("aws_session_token", None) @@ -1437,7 +1438,7 @@ def image_generation( Bedrock Image Gen endpoint support """ ### BOTO3 INIT ### - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_session_token = optional_params.pop("aws_session_token", None) diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 3eb8acd390..d00695f870 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -745,7 +745,7 @@ class BedrockLLM(BaseLLM): provider = model.split(".")[0] ## CREDENTIALS ## - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_session_token = optional_params.pop("aws_session_token", None) diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 8e75428bb7..7d639b7bb2 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -162,9 +162,10 @@ def completion( ): import boto3 - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = optional_params.pop("aws_region_name", None) model_id = optional_params.pop("model_id", None) @@ -175,6 +176,7 @@ def completion( service_name="sagemaker-runtime", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, region_name=aws_region_name, ) else: @@ -249,6 +251,7 @@ def completion( model_id=model_id, aws_secret_access_key=aws_secret_access_key, aws_access_key_id=aws_access_key_id, + aws_session_token=aws_session_token, aws_region_name=aws_region_name, ) return response @@ -281,6 +284,7 @@ def completion( model_id=model_id, aws_secret_access_key=aws_secret_access_key, aws_access_key_id=aws_access_key_id, + aws_session_token=aws_session_token, aws_region_name=aws_region_name, ) data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode( @@ -414,6 +418,7 @@ async def async_streaming( aws_secret_access_key: Optional[str], aws_access_key_id: Optional[str], aws_region_name: Optional[str], + aws_session_token: Optional[str] = None, ): """ Use aioboto3 @@ -429,6 +434,7 @@ async def async_streaming( service_name="sagemaker-runtime", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, region_name=aws_region_name, ) else: @@ -481,6 +487,7 @@ async def async_completion( aws_secret_access_key: Optional[str], aws_access_key_id: Optional[str], aws_region_name: Optional[str], + aws_session_token: Optional[str] = None, ): """ Use aioboto3 @@ -496,6 +503,7 @@ async def async_completion( service_name="sagemaker-runtime", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, region_name=aws_region_name, ) else: @@ -639,9 +647,10 @@ def embedding( ### BOTO3 INIT import boto3 - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = optional_params.pop("aws_region_name", None) if aws_access_key_id is not None: @@ -651,6 +660,7 @@ def embedding( service_name="sagemaker-runtime", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, region_name=aws_region_name, ) else: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 30b90abe64..9c1039f51e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6343,6 +6343,7 @@ async def model_info_v2( _model["litellm_params"].pop("vertex_credentials", None) _model["litellm_params"].pop("aws_access_key_id", None) _model["litellm_params"].pop("aws_secret_access_key", None) + _model["litellm_params"].pop("aws_session_token", None) verbose_proxy_logger.debug("all_models: %s", all_models) return {"data": all_models} @@ -6859,6 +6860,7 @@ async def model_info_v1( model["litellm_params"].pop("vertex_credentials", None) model["litellm_params"].pop("aws_access_key_id", None) model["litellm_params"].pop("aws_secret_access_key", None) + model["litellm_params"].pop("aws_session_token", None) verbose_proxy_logger.debug("all_models: %s", all_models) return {"data": all_models} diff --git a/litellm/types/router.py b/litellm/types/router.py index e6864ffe2e..059a8620e5 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -145,6 +145,7 @@ class GenericLiteLLMParams(BaseModel): ## AWS BEDROCK / SAGEMAKER ## aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None aws_region_name: Optional[str] = None ## IBM WATSONX ## watsonx_region_name: Optional[str] = None @@ -178,6 +179,7 @@ class GenericLiteLLMParams(BaseModel): ## AWS BEDROCK / SAGEMAKER ## aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = None, ## IBM WATSONX ## watsonx_region_name: Optional[str] = None, @@ -242,6 +244,7 @@ class LiteLLM_Params(GenericLiteLLMParams): ## AWS BEDROCK / SAGEMAKER ## aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, aws_region_name: Optional[str] = None, **params, ): @@ -307,6 +310,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): ## AWS BEDROCK / SAGEMAKER ## aws_access_key_id: Optional[str] aws_secret_access_key: Optional[str] + aws_session_token: Optional[str] aws_region_name: Optional[str] ## IBM WATSONX ## watsonx_region_name: Optional[str]