This commit is contained in:
Niko Izsak 2025-04-24 00:54:43 -07:00 committed by GitHub
commit 2ee258df3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 222 additions and 37 deletions

View file

@ -19,6 +19,7 @@ LiteLLM supports the following OIDC identity providers:
| CircleCI v2 | `circleci_v2`| No | | CircleCI v2 | `circleci_v2`| No |
| GitHub Actions | `github` | Yes | | GitHub Actions | `github` | Yes |
| Azure Kubernetes Service | `azure` | No | | Azure Kubernetes Service | `azure` | No |
| Azure AD | `azure` | Yes |
| File | `file` | No | | File | `file` | No |
| Environment Variable | `env` | No | | Environment Variable | `env` | No |
| Environment Path | `env_path` | No | | Environment Path | `env_path` | No |
@ -261,3 +262,15 @@ The custom role below is the recommended minimum permissions for the Azure appli
_Note: Your UUIDs will be different._ _Note: Your UUIDs will be different._
Please contact us for paid enterprise support if you need help setting up Azure AD applications. Please contact us for paid enterprise support if you need help setting up Azure AD applications.
### Azure AD -> Amazon Bedrock
```yaml
model_list:
- model_name: aws/claude-3-5-sonnet
litellm_params:
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
aws_region_name: "eu-central-1"
aws_role_name: "arn:aws:iam::12345678:role/bedrock-role"
aws_web_identity_token: "oidc/azure/api://123-456-789-9d04"
aws_session_name: "litellm-session"
```

View file

@ -1,8 +1,8 @@
import os import os
from typing import Callable from typing import Callable, Optional
def get_azure_ad_token_provider() -> Callable[[], str]: def get_azure_ad_token_provider(azure_scope: Optional[str] = None) -> Callable[[], str]:
""" """
Get Azure AD token provider based on Service Principal with Secret workflow. Get Azure AD token provider based on Service Principal with Secret workflow.
@ -11,15 +11,22 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret; https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python. https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.
Args:
azure_scope (str, optional): The Azure scope to request token for.
Defaults to environment variable AZURE_SCOPE or
"https://cognitiveservices.azure.com/.default".
Returns: Returns:
Callable that returns a temporary authentication token. Callable that returns a temporary authentication token.
""" """
import azure.identity as identity import azure.identity as identity
from azure.identity import get_bearer_token_provider from azure.identity import get_bearer_token_provider
azure_scope = os.environ.get( if azure_scope is None:
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default" azure_scope = os.environ.get(
) "AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
)
cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential") cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")
cred_cls = getattr(identity, cred) cred_cls = getattr(identity, cred)

View file

@ -12,6 +12,9 @@ from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.http_handler import HTTPHandler from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.proxy._types import KeyManagementSystem from litellm.proxy._types import KeyManagementSystem
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
oidc_cache = DualCache() oidc_cache = DualCache()
@ -102,6 +105,7 @@ def get_secret( # noqa: PLR0915
if secret_name.startswith("oidc/"): if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "") secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1) oidc_provider, oidc_aud = secret_name_split.split("/", 1)
oidc_aud = "/".join(secret_name_split.split("/")[1:])
# TODO: Add caching for HTTP requests # TODO: Add caching for HTTP requests
if oidc_provider == "google": if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name) oidc_token = oidc_cache.get_cache(key=secret_name)
@ -137,10 +141,7 @@ def get_secret( # noqa: PLR0915
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL") actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if ( if actions_id_token_request_url is None or actions_id_token_request_token is None:
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
raise ValueError( raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment" "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
) )
@ -168,7 +169,19 @@ def get_secret( # noqa: PLR0915
# https://azure.github.io/azure-workload-identity/docs/quick-start.html # https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE") azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None: if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment") verbose_logger.warning(
"AZURE_FEDERATED_TOKEN_FILE not found in environment will use Azure AD token provider"
)
azure_token_provider = get_azure_ad_token_provider(azure_scope=oidc_aud)
try:
oidc_token = azure_token_provider()
if oidc_token is None:
raise ValueError("Azure OIDC provider returned None token")
return oidc_token
except Exception as e:
error_msg = f"Azure OIDC provider failed: {str(e)}"
verbose_logger.error(error_msg)
raise ValueError(error_msg)
with open(azure_federated_token_file, "r") as f: with open(azure_federated_token_file, "r") as f:
oidc_token = f.read() oidc_token = f.read()
return oidc_token return oidc_token
@ -195,10 +208,7 @@ def get_secret( # noqa: PLR0915
raise ValueError("Unsupported OIDC provider") raise ValueError("Unsupported OIDC provider")
try: try:
if ( if _should_read_secret_from_secret_manager() and litellm.secret_manager_client is not None:
_should_read_secret_from_secret_manager()
and litellm.secret_manager_client is not None
):
try: try:
client = litellm.secret_manager_client client = litellm.secret_manager_client
key_manager = "local" key_manager = "local"
@ -224,9 +234,7 @@ def get_secret( # noqa: PLR0915
): ):
encrypted_secret: Any = os.getenv(secret_name) encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None: if encrypted_secret is None:
raise ValueError( raise ValueError("Google KMS requires the encrypted secret to be in the environment!")
"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret) b64_flag = _is_base64(encrypted_secret)
if b64_flag is True: # if passed in as encoded b64 string if b64_flag is True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret) encrypted_secret = base64.b64decode(encrypted_secret)
@ -241,20 +249,14 @@ def get_secret( # noqa: PLR0915
"ciphertext": ciphertext, "ciphertext": ciphertext,
} }
) )
secret = response.plaintext.decode( secret = response.plaintext.decode("utf-8") # assumes the original value was encoded with utf-8
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value: elif key_manager == KeyManagementSystem.AWS_KMS.value:
""" """
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys. Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
""" """
encrypted_value = os.getenv(secret_name, None) encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None: if encrypted_value is None:
raise Exception( raise Exception("AWS KMS - Encrypted Value of Key={} is None".format(secret_name))
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
# Decode the base64 encoded ciphertext # Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value) ciphertext_blob = base64.b64decode(encrypted_value)
@ -281,14 +283,10 @@ def get_secret( # noqa: PLR0915
print_verbose(f"get_secret_value_response: {secret}") print_verbose(f"get_secret_value_response: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try: try:
secret = client.get_secret_from_google_secret_manager( secret = client.get_secret_from_google_secret_manager(secret_name)
secret_name
)
print_verbose(f"secret from google secret manager: {secret}") print_verbose(f"secret from google secret manager: {secret}")
if secret is None: if secret is None:
raise ValueError( raise ValueError(f"No secret found in Google Secret Manager for {secret_name}")
f"No secret found in Google Secret Manager for {secret_name}"
)
except Exception as e: except Exception as e:
print_verbose(f"An error occurred - {str(e)}") print_verbose(f"An error occurred - {str(e)}")
raise e raise e
@ -296,9 +294,7 @@ def get_secret( # noqa: PLR0915
try: try:
secret = client.sync_read_secret(secret_name=secret_name) secret = client.sync_read_secret(secret_name=secret_name)
if secret is None: if secret is None:
raise ValueError( raise ValueError(f"No secret found in Hashicorp Secret Manager for {secret_name}")
f"No secret found in Hashicorp Secret Manager for {secret_name}"
)
except Exception as e: except Exception as e:
print_verbose(f"An error occurred - {str(e)}") print_verbose(f"An error occurred - {str(e)}")
raise e raise e
@ -323,9 +319,7 @@ def get_secret( # noqa: PLR0915
else: else:
secret = os.environ.get(secret_name) secret = os.environ.get(secret_name)
secret_value_as_bool = str_to_bool(secret) if secret is not None else None secret_value_as_bool = str_to_bool(secret) if secret is not None else None
if secret_value_as_bool is not None and isinstance( if secret_value_as_bool is not None and isinstance(secret_value_as_bool, bool):
secret_value_as_bool, bool
):
return secret_value_as_bool return secret_value_as_bool
else: else:
return secret return secret

View file

@ -0,0 +1,171 @@
import pytest
import os
from unittest.mock import Mock, patch
from litellm.secret_managers.main import get_secret
import logging
# Set up logging for debugging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Mock HTTPHandler and oidc_cache
class MockHTTPHandler:
    """Minimal stand-in for the HTTP handler used by ``get_secret``.

    Every ``get`` call returns a canned response built from the instance's
    ``status_code``, ``text`` and ``json_data`` attributes, which tests may
    overwrite before triggering a request. The query parameters of the most
    recent call are kept in ``last_params`` so tests can verify the audience.
    """

    def __init__(self, timeout):
        self.timeout = timeout
        self.status_code = 200
        self.text = "mocked_token"
        self.json_data = {"value": "mocked_token"}
        # Initialised up front so that inspecting it before any request gives
        # a clean assertion failure (None) rather than an AttributeError.
        self.last_params = None

    def get(self, url, params=None, headers=None):
        """Record ``params`` and return a mock response with the canned values."""
        # Store params for audience verification
        self.last_params = params
        logger.debug(f"MockHTTPHandler.get called with url={url}, params={params}, headers={headers}")
        mock_response = Mock()
        mock_response.status_code = self.status_code
        mock_response.text = self.text
        mock_response.json.return_value = self.json_data
        return mock_response
@pytest.fixture
def mock_oidc_cache():
    """Provide a stand-in cache object whose ``get_cache`` lookups always miss."""
    stub = Mock()
    stub.set_cache = Mock()
    stub.get_cache.return_value = None
    return stub
@pytest.fixture
def mock_env():
    """Yield ``os.environ`` emptied for the duration of the test.

    ``patch.dict`` restores the original environment when the test finishes.
    """
    emptied_environment = patch.dict(os.environ, {}, clear=True)
    with emptied_environment:
        yield os.environ
@patch('litellm.secret_managers.main.oidc_cache')
@patch('litellm.secret_managers.main.HTTPHandler')
def test_oidc_google_success(mock_http_handler, mock_oidc_cache):
    """On a cache miss the Google branch fetches, returns and caches the token."""
    fake_client = MockHTTPHandler(timeout=600.0)
    mock_http_handler.return_value = fake_client
    mock_oidc_cache.get_cache.return_value = None  # force the HTTP path

    secret_name = "oidc/google/[invalid url, do not cite]"
    token = get_secret(secret_name)

    assert token == "mocked_token"
    # Everything after the provider segment is forwarded as the audience.
    assert fake_client.last_params == {"audience": "[invalid url, do not cite]"}
    mock_oidc_cache.set_cache.assert_called_once_with(
        key=secret_name, value="mocked_token", ttl=3540
    )
@patch('litellm.secret_managers.main.oidc_cache')
def test_oidc_google_cached(mock_oidc_cache):
    """A cached Google OIDC token is returned without building an HTTP client."""
    mock_oidc_cache.get_cache.return_value = "cached_token"
    secret_name = "oidc/google/[invalid url, do not cite]"
    # Patch the name get_secret actually resolves: HTTPHandler is imported
    # into litellm.secret_managers.main, so patching litellm.HTTPHandler would
    # leave the real handler in place and make assert_not_called() vacuous.
    with patch('litellm.secret_managers.main.HTTPHandler') as mock_http:
        result = get_secret(secret_name)
    assert result == "cached_token", f"Expected cached token, got {result}"
    mock_oidc_cache.get_cache.assert_called_with(key=secret_name)
    mock_http.assert_not_called()
def test_oidc_google_failure(mock_oidc_cache):
    """A non-200 response from the Google metadata endpoint raises ValueError."""
    mock_handler = MockHTTPHandler(timeout=600.0)
    mock_handler.status_code = 400
    mock_oidc_cache.get_cache.return_value = None
    secret_name = "oidc/google/https://example.com/api"
    # The fixture is only a plain Mock; it must be patched into the module,
    # otherwise get_secret consults the real process-wide oidc_cache and the
    # cache-miss set up above has no effect on the code under test.
    with patch('litellm.secret_managers.main.oidc_cache', mock_oidc_cache), \
            patch('litellm.secret_managers.main.HTTPHandler', return_value=mock_handler):
        with pytest.raises(ValueError, match="Google OIDC provider failed"):
            get_secret(secret_name)
def test_oidc_circleci_success(mock_env):
    """The CircleCI branch returns the value of ``CIRCLE_OIDC_TOKEN``."""
    mock_env.update({"CIRCLE_OIDC_TOKEN": "circleci_token"})
    token = get_secret("oidc/circleci/test-audience")
    assert token == "circleci_token"
def test_oidc_circleci_failure(mock_env):
    """Without ``CIRCLE_OIDC_TOKEN`` in the environment a ValueError is raised."""
    # mock_env empties os.environ so the test also passes when it is executed
    # on CircleCI itself, where CIRCLE_OIDC_TOKEN is really set.
    secret_name = "oidc/circleci/test-audience"
    with pytest.raises(ValueError, match="CIRCLE_OIDC_TOKEN not found in environment"):
        get_secret(secret_name)
@patch('litellm.secret_managers.main.oidc_cache')
@patch('litellm.secret_managers.main.HTTPHandler')
def test_oidc_github_success(mock_http_handler, mock_oidc_cache, mock_env):
    """With both Actions variables set, the GitHub branch fetches and caches a token."""
    mock_env.update(
        {
            "ACTIONS_ID_TOKEN_REQUEST_URL": "https://github.com/token",
            "ACTIONS_ID_TOKEN_REQUEST_TOKEN": "github_token",
        }
    )
    fake_client = MockHTTPHandler(timeout=600.0)
    mock_http_handler.return_value = fake_client
    mock_oidc_cache.get_cache.return_value = None  # force the HTTP path

    secret_name = "oidc/github/github-audience"
    result = get_secret(secret_name)

    assert result == "mocked_token", f"Expected token 'mocked_token', got {result}"
    assert fake_client.last_params == {"audience": "github-audience"}
    logger.debug(f"set_cache call args: {mock_oidc_cache.set_cache.call_args}")
    # Exactly one cache write, carrying the fetched token and its TTL.
    mock_oidc_cache.set_cache.assert_called_once_with(
        key=secret_name, value="mocked_token", ttl=295
    )
def test_oidc_github_missing_env(mock_env):
    """Missing GitHub Actions token variables cause a ValueError."""
    # mock_env empties os.environ so the test also passes on GitHub Actions
    # runners that export ACTIONS_ID_TOKEN_REQUEST_URL / _TOKEN.
    secret_name = "oidc/github/github-audience"
    with pytest.raises(ValueError, match="ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"):
        get_secret(secret_name)
def test_oidc_azure_file_success(mock_env, tmp_path):
    """When ``AZURE_FEDERATED_TOKEN_FILE`` is set, the token is read from that file."""
    federated_token_path = tmp_path / "token.txt"
    federated_token_path.write_text("azure_token")
    mock_env.update({"AZURE_FEDERATED_TOKEN_FILE": str(federated_token_path)})

    token = get_secret("oidc/azure/azure-audience")

    assert token == "azure_token"
@patch('litellm.secret_managers.main.get_azure_ad_token_provider')
def test_oidc_azure_ad_token_success(mock_get_azure_ad_token_provider, mock_env):
    """Without a federated token file, the Azure AD token provider supplies the token.

    The audience part of the secret name (which itself contains slashes) must
    be passed through unchanged as ``azure_scope``.
    """
    # mock_env empties os.environ: if AZURE_FEDERATED_TOKEN_FILE were set on
    # the host, get_secret would read that file and never reach the provider.
    mock_token_provider = Mock(return_value="azure_ad_token")
    mock_get_azure_ad_token_provider.return_value = mock_token_provider
    secret_name = "oidc/azure/api://azure-audience"
    result = get_secret(secret_name)
    assert result == "azure_ad_token"
    mock_get_azure_ad_token_provider.assert_called_once_with(azure_scope="api://azure-audience")
    mock_token_provider.assert_called_once_with()
def test_oidc_file_success(tmp_path):
    """The ``file`` provider returns the contents of the path in the secret name."""
    token_path = tmp_path / "token.txt"
    token_path.write_text("file_token")

    token = get_secret(f"oidc/file/{token_path}")

    assert token == "file_token"
def test_oidc_env_success(mock_env):
    """The ``env`` provider returns the value of the named environment variable."""
    mock_env.update({"CUSTOM_TOKEN": "env_token"})
    token = get_secret("oidc/env/CUSTOM_TOKEN")
    assert token == "env_token"
def test_oidc_env_path_success(mock_env, tmp_path):
    """The ``env_path`` provider reads the file whose path is held in the named variable."""
    token_path = tmp_path / "token.txt"
    token_path.write_text("env_path_token")
    mock_env.update({"TOKEN_PATH": str(token_path)})

    token = get_secret("oidc/env_path/TOKEN_PATH")

    assert token == "env_path_token"
def test_unsupported_oidc_provider():
    """An unrecognised provider segment is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Unsupported OIDC provider"):
        get_secret("oidc/unsupported/unsupported-audience")