From 379f98901fd75e1b83f3eb292b86e5f65ad260ad Mon Sep 17 00:00:00 2001 From: Gregor Karetka Date: Wed, 5 Mar 2025 14:55:57 +0100 Subject: [PATCH] add: Allow auth extra_headers to be set via .env - Allow extra_headers to be set via .env - Add env to docs --- docs/my-website/docs/proxy/config_settings.md | 1 + litellm/llms/bedrock/base_aws_llm.py | 12 +++++++++ .../litellm/llms/bedrock/test_base_aws_llm.py | 25 +++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 1e3c800b03..3412abb46a 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -322,6 +322,7 @@ router_settings: | AWS_SECRET_ACCESS_KEY | Secret Access Key for AWS services | AWS_SESSION_NAME | Name for AWS session | AWS_WEB_IDENTITY_TOKEN | Web identity token for AWS +| AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN | Bearer token for custom API endpoints | AZURE_API_VERSION | Version of the Azure API being used | AZURE_AUTHORITY_HOST | Azure authority host URL | AZURE_CLIENT_ID | Client ID for Azure services diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py index 133ef6a952..a5eea9fee6 100644 --- a/litellm/llms/bedrock/base_aws_llm.py +++ b/litellm/llms/bedrock/base_aws_llm.py @@ -56,6 +56,7 @@ class BaseAWSLLM: "aws_sts_endpoint", "aws_bedrock_runtime_endpoint", ] + self.env_extra_headers = {} def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str: """ @@ -200,6 +201,11 @@ class BaseAWSLLM: else: credentials, _cache_ttl = self._auth_with_env_vars() + env_aws_extra_headers_auth_bearer_token = get_secret("AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN") + if env_aws_extra_headers_auth_bearer_token: + verbose_logger.debug("Using Bearer token form env") + self.env_extra_headers['Authorization'] = f"Bearer {env_aws_extra_headers_auth_bearer_token}" + self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl) return credentials @@ -618,6 +624,12 @@ class BaseAWSLLM: method="POST", url=endpoint_url, data=data, headers=headers ) sigv4.add_auth(request) + + if extra_headers is None and len(self.env_extra_headers) > 0: + extra_headers = self.env_extra_headers + elif extra_headers is not None: + extra_headers.update(self.env_extra_headers) + if ( extra_headers is not None and "Authorization" in extra_headers ): # prevent sigv4 from overwriting the auth header diff --git a/tests/litellm/llms/bedrock/test_base_aws_llm.py b/tests/litellm/llms/bedrock/test_base_aws_llm.py index 3a2f691c1d..3dc10f40e9 100644 --- a/tests/litellm/llms/bedrock/test_base_aws_llm.py +++ b/tests/litellm/llms/bedrock/test_base_aws_llm.py @@ -98,3 +98,28 @@ def test_auth_functions_tracer_wrapping(): assert ( has_tracer_wrap ), f"Auth function on line {line_number} is not wrapped with @tracer.wrap: {line.strip()}" + + +def test_loading_bearer_token_from_env(): + aws_region_name = "us" + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + data = "" + headers = {} + + os.environ['AWS_SECRET_ACCESS_KEY'] = "fake_aws_secret_access_key" + os.environ['AWS_ACCESS_KEY_ID'] = "fake_aws_access_key_id" + + extra_headers = {} + base_aws_llm = BaseAWSLLM() + credentials = base_aws_llm.get_credentials() + aws_prep_req = base_aws_llm.get_request_headers(credentials, aws_region_name, extra_headers, endpoint_url, data, headers) + assert 'Authorization' in aws_prep_req.headers.keys() + assert 'Bearer' not in aws_prep_req.headers["Authorization"] + + os.environ['AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN'] = "fake_token" + extra_headers = {} + base_aws_llm = BaseAWSLLM() + credentials = base_aws_llm.get_credentials() + aws_prep_req = base_aws_llm.get_request_headers(credentials, aws_region_name, extra_headers, endpoint_url, data, headers) + assert 'Authorization' in aws_prep_req.headers.keys() + assert aws_prep_req.headers['Authorization'] == "Bearer fake_token"