add: Allow auth extra_headers to be set via .env

- Allow extra_headers to be set via .env
- Add env to docs
This commit is contained in:
Gregor Karetka 2025-03-05 14:55:57 +01:00
parent f5996b2f6b
commit 379f98901f
3 changed files with 38 additions and 0 deletions

View file

@ -322,6 +322,7 @@ router_settings:
| AWS_SECRET_ACCESS_KEY | Secret Access Key for AWS services
| AWS_SESSION_NAME | Name for AWS session
| AWS_WEB_IDENTITY_TOKEN | Web identity token for AWS
| AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN | Bearer token for custom API endpoints
| AZURE_API_VERSION | Version of the Azure API being used
| AZURE_AUTHORITY_HOST | Azure authority host URL
| AZURE_CLIENT_ID | Client ID for Azure services

View file

@ -56,6 +56,7 @@ class BaseAWSLLM:
"aws_sts_endpoint",
"aws_bedrock_runtime_endpoint",
]
self.env_extra_headers = {}
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
"""
@ -200,6 +201,11 @@ class BaseAWSLLM:
else:
credentials, _cache_ttl = self._auth_with_env_vars()
env_aws_extra_headers_auth_bearer_token = get_secret("AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN")
if env_aws_extra_headers_auth_bearer_token:
verbose_logger.debug("Using Bearer token form env")
self.env_extra_headers['Authorization'] = f"Bearer {env_aws_extra_headers_auth_bearer_token}"
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
return credentials
@ -618,6 +624,12 @@ class BaseAWSLLM:
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
if extra_headers is None and len(self.env_extra_headers) > 0:
extra_headers = self.env_extra_headers
elif extra_headers is not None:
extra_headers.update(self.env_extra_headers)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header

View file

@ -98,3 +98,28 @@ def test_auth_functions_tracer_wrapping():
assert (
has_tracer_wrap
), f"Auth function on line {line_number} is not wrapped with @tracer.wrap: {line.strip()}"
def test_loading_bearer_token_from_env():
aws_region_name = "us"
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
data = ""
headers = {}
os.environ['AWS_SECRET_ACCESS_KEY'] = "fake_aws_secret_access_key"
os.environ['AWS_ACCESS_KEY_ID'] = "fake_aws_access_key_id"
extra_headers = {}
base_aws_llm = BaseAWSLLM()
credentials = base_aws_llm.get_credentials()
aws_prep_req = base_aws_llm.get_request_headers(credentials, aws_region_name, extra_headers, endpoint_url, data, headers)
assert 'Authorization' in aws_prep_req.headers.keys()
assert 'Bearer' not in aws_prep_req.headers["Authorization"]
os.environ['AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN'] = "fake_token"
extra_headers = {}
base_aws_llm = BaseAWSLLM()
credentials = base_aws_llm.get_credentials()
aws_prep_req = base_aws_llm.get_request_headers(credentials, aws_region_name, extra_headers, endpoint_url, data, headers)
assert 'Authorization' in aws_prep_req.headers.keys()
assert aws_prep_req.headers['Authorization'] == "Bearer fake_token"