mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
add: Allow auth extra_headers to be set via .env
- Allow extra_headers to be set via .env - Add env to docs
This commit is contained in:
parent
f5996b2f6b
commit
379f98901f
3 changed files with 38 additions and 0 deletions
|
@ -322,6 +322,7 @@ router_settings:
|
|||
| AWS_SECRET_ACCESS_KEY | Secret Access Key for AWS services
|
||||
| AWS_SESSION_NAME | Name for AWS session
|
||||
| AWS_WEB_IDENTITY_TOKEN | Web identity token for AWS
|
||||
| AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN | Bearer token for custom API endpoints
|
||||
| AZURE_API_VERSION | Version of the Azure API being used
|
||||
| AZURE_AUTHORITY_HOST | Azure authority host URL
|
||||
| AZURE_CLIENT_ID | Client ID for Azure services
|
||||
|
|
|
@ -56,6 +56,7 @@ class BaseAWSLLM:
|
|||
"aws_sts_endpoint",
|
||||
"aws_bedrock_runtime_endpoint",
|
||||
]
|
||||
self.env_extra_headers = {}
|
||||
|
||||
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
|
||||
"""
|
||||
|
@ -200,6 +201,11 @@ class BaseAWSLLM:
|
|||
else:
|
||||
credentials, _cache_ttl = self._auth_with_env_vars()
|
||||
|
||||
env_aws_extra_headers_auth_bearer_token = get_secret("AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN")
|
||||
if env_aws_extra_headers_auth_bearer_token:
|
||||
verbose_logger.debug("Using Bearer token form env")
|
||||
self.env_extra_headers['Authorization'] = f"Bearer {env_aws_extra_headers_auth_bearer_token}"
|
||||
|
||||
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
|
||||
return credentials
|
||||
|
||||
|
@ -618,6 +624,12 @@ class BaseAWSLLM:
|
|||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
|
||||
if extra_headers is None and len(self.env_extra_headers) > 0:
|
||||
extra_headers = self.env_extra_headers
|
||||
elif extra_headers is not None:
|
||||
extra_headers.update(self.env_extra_headers)
|
||||
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
|
|
|
@ -98,3 +98,28 @@ def test_auth_functions_tracer_wrapping():
|
|||
assert (
|
||||
has_tracer_wrap
|
||||
), f"Auth function on line {line_number} is not wrapped with @tracer.wrap: {line.strip()}"
|
||||
|
||||
|
||||
def test_loading_bearer_token_from_env():
|
||||
aws_region_name = "us"
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
data = ""
|
||||
headers = {}
|
||||
|
||||
os.environ['AWS_SECRET_ACCESS_KEY'] = "fake_aws_secret_access_key"
|
||||
os.environ['AWS_ACCESS_KEY_ID'] = "fake_aws_access_key_id"
|
||||
|
||||
extra_headers = {}
|
||||
base_aws_llm = BaseAWSLLM()
|
||||
credentials = base_aws_llm.get_credentials()
|
||||
aws_prep_req = base_aws_llm.get_request_headers(credentials, aws_region_name, extra_headers, endpoint_url, data, headers)
|
||||
assert 'Authorization' in aws_prep_req.headers.keys()
|
||||
assert 'Bearer' not in aws_prep_req.headers["Authorization"]
|
||||
|
||||
os.environ['AWS_EXTRA_HEADERS_AUTH_BEARER_TOKEN'] = "fake_token"
|
||||
extra_headers = {}
|
||||
base_aws_llm = BaseAWSLLM()
|
||||
credentials = base_aws_llm.get_credentials()
|
||||
aws_prep_req = base_aws_llm.get_request_headers(credentials, aws_region_name, extra_headers, endpoint_url, data, headers)
|
||||
assert 'Authorization' in aws_prep_req.headers.keys()
|
||||
assert aws_prep_req.headers['Authorization'] == "Bearer fake_token"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue