This commit is contained in:
Niko Izsak 2025-04-24 00:54:43 -07:00 committed by GitHub
commit 2ee258df3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 222 additions and 37 deletions

View file

@ -19,6 +19,7 @@ LiteLLM supports the following OIDC identity providers:
| CircleCI v2 | `circleci_v2`| No |
| GitHub Actions | `github` | Yes |
| Azure Kubernetes Service | `azure` | No |
| Azure AD | `azure` | Yes |
| File | `file` | No |
| Environment Variable | `env` | No |
| Environment Path | `env_path` | No |
@ -261,3 +262,15 @@ The custom role below is the recommended minimum permissions for the Azure appli
_Note: Your UUIDs will be different._
Please contact us for paid enterprise support if you need help setting up Azure AD applications.
### Azure AD -> Amazon Bedrock
```yaml
model list:
- model_name: aws/claude-3-5-sonnet
litellm_params:
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
aws_region_name: "eu-central-1"
aws_role_name: "arn:aws:iam::12345678:role/bedrock-role"
aws_web_identity_token: "oidc/azure/api://123-456-789-9d04"
aws_session_name: "litellm-session"
```

View file

@ -1,8 +1,8 @@
import os
from typing import Callable
from typing import Callable, Optional
def get_azure_ad_token_provider() -> Callable[[], str]:
def get_azure_ad_token_provider(azure_scope: Optional[str] = None) -> Callable[[], str]:
"""
Get Azure AD token provider based on Service Principal with Secret workflow.
@ -11,15 +11,22 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.
Args:
azure_scope (str, optional): The Azure scope to request token for.
Defaults to environment variable AZURE_SCOPE or
"https://cognitiveservices.azure.com/.default".
Returns:
Callable that returns a temporary authentication token.
"""
import azure.identity as identity
from azure.identity import get_bearer_token_provider
azure_scope = os.environ.get(
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
)
if azure_scope is None:
azure_scope = os.environ.get(
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
)
cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")
cred_cls = getattr(identity, cred)

View file

@ -12,6 +12,9 @@ from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.proxy._types import KeyManagementSystem
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
oidc_cache = DualCache()
@ -102,6 +105,7 @@ def get_secret( # noqa: PLR0915
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
oidc_aud = "/".join(secret_name_split.split("/")[1:])
# TODO: Add caching for HTTP requests
if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
@ -137,10 +141,7 @@ def get_secret( # noqa: PLR0915
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if (
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
if actions_id_token_request_url is None or actions_id_token_request_token is None:
raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
)
@ -168,7 +169,19 @@ def get_secret( # noqa: PLR0915
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
verbose_logger.warning(
"AZURE_FEDERATED_TOKEN_FILE not found in environment will use Azure AD token provider"
)
azure_token_provider = get_azure_ad_token_provider(azure_scope=oidc_aud)
try:
oidc_token = azure_token_provider()
if oidc_token is None:
raise ValueError("Azure OIDC provider returned None token")
return oidc_token
except Exception as e:
error_msg = f"Azure OIDC provider failed: {str(e)}"
verbose_logger.error(error_msg)
raise ValueError(error_msg)
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
@ -195,10 +208,7 @@ def get_secret( # noqa: PLR0915
raise ValueError("Unsupported OIDC provider")
try:
if (
_should_read_secret_from_secret_manager()
and litellm.secret_manager_client is not None
):
if _should_read_secret_from_secret_manager() and litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
@ -224,9 +234,7 @@ def get_secret( # noqa: PLR0915
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
"Google KMS requires the encrypted secret to be in the environment!"
)
raise ValueError("Google KMS requires the encrypted secret to be in the environment!")
b64_flag = _is_base64(encrypted_secret)
if b64_flag is True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
@ -241,20 +249,14 @@ def get_secret( # noqa: PLR0915
"ciphertext": ciphertext,
}
)
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
secret = response.plaintext.decode("utf-8") # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception(
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
raise Exception("AWS KMS - Encrypted Value of Key={} is None".format(secret_name))
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
@ -281,14 +283,10 @@ def get_secret( # noqa: PLR0915
print_verbose(f"get_secret_value_response: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try:
secret = client.get_secret_from_google_secret_manager(
secret_name
)
secret = client.get_secret_from_google_secret_manager(secret_name)
print_verbose(f"secret from google secret manager: {secret}")
if secret is None:
raise ValueError(
f"No secret found in Google Secret Manager for {secret_name}"
)
raise ValueError(f"No secret found in Google Secret Manager for {secret_name}")
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
raise e
@ -296,9 +294,7 @@ def get_secret( # noqa: PLR0915
try:
secret = client.sync_read_secret(secret_name=secret_name)
if secret is None:
raise ValueError(
f"No secret found in Hashicorp Secret Manager for {secret_name}"
)
raise ValueError(f"No secret found in Hashicorp Secret Manager for {secret_name}")
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
raise e
@ -323,9 +319,7 @@ def get_secret( # noqa: PLR0915
else:
secret = os.environ.get(secret_name)
secret_value_as_bool = str_to_bool(secret) if secret is not None else None
if secret_value_as_bool is not None and isinstance(
secret_value_as_bool, bool
):
if secret_value_as_bool is not None and isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret

View file

@ -0,0 +1,171 @@
import pytest
import os
from unittest.mock import Mock, patch
from litellm.secret_managers.main import get_secret
import logging
# Set up logging for debugging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Mock HTTPHandler and oidc_cache
class MockHTTPHandler:
def __init__(self, timeout):
self.timeout = timeout
self.status_code = 200
self.text = "mocked_token"
self.json_data = {"value": "mocked_token"}
def get(self, url, params=None, headers=None):
# Store params for audience verification
self.last_params = params
logger.debug(f"MockHTTPHandler.get called with url={url}, params={params}, headers={headers}")
mock_response = Mock()
mock_response.status_code = self.status_code
mock_response.text = self.text
mock_response.json.return_value = self.json_data
return mock_response
@pytest.fixture
def mock_oidc_cache():
cache = Mock()
cache.get_cache.return_value = None
cache.set_cache = Mock()
return cache
@pytest.fixture
def mock_env():
with patch.dict(os.environ, {}, clear=True):
yield os.environ
@patch('litellm.secret_managers.main.oidc_cache')
@patch('litellm.secret_managers.main.HTTPHandler')
def test_oidc_google_success(mock_http_handler, mock_oidc_cache):
mock_oidc_cache.get_cache.return_value = None
mock_handler = MockHTTPHandler(timeout=600.0)
mock_http_handler.return_value = mock_handler
secret_name = "oidc/google/[invalid url, do not cite]"
result = get_secret(secret_name)
assert result == "mocked_token"
assert mock_handler.last_params == {"audience": "[invalid url, do not cite]"}
mock_oidc_cache.set_cache.assert_called_once_with(
key=secret_name, value="mocked_token", ttl=3540
)
@patch('litellm.secret_managers.main.oidc_cache')
def test_oidc_google_cached(mock_oidc_cache):
mock_oidc_cache.get_cache.return_value = "cached_token"
secret_name = "oidc/google/[invalid url, do not cite]"
with patch('litellm.HTTPHandler') as mock_http:
result = get_secret(secret_name)
assert result == "cached_token", f"Expected cached token, got {result}"
mock_oidc_cache.get_cache.assert_called_with(key=secret_name)
mock_http.assert_not_called()
def test_oidc_google_failure(mock_oidc_cache):
mock_handler = MockHTTPHandler(timeout=600.0)
mock_handler.status_code = 400
with patch('litellm.secret_managers.main.HTTPHandler', return_value=mock_handler):
mock_oidc_cache.get_cache.return_value = None
secret_name = "oidc/google/https://example.com/api"
with pytest.raises(ValueError, match="Google OIDC provider failed"):
get_secret(secret_name)
def test_oidc_circleci_success(mock_env):
mock_env["CIRCLE_OIDC_TOKEN"] = "circleci_token"
secret_name = "oidc/circleci/test-audience"
result = get_secret(secret_name)
assert result == "circleci_token"
def test_oidc_circleci_failure():
secret_name = "oidc/circleci/test-audience"
with pytest.raises(ValueError, match="CIRCLE_OIDC_TOKEN not found in environment"):
get_secret(secret_name)
@patch('litellm.secret_managers.main.oidc_cache')
@patch('litellm.secret_managers.main.HTTPHandler')
def test_oidc_github_success(mock_http_handler, mock_oidc_cache, mock_env):
mock_env["ACTIONS_ID_TOKEN_REQUEST_URL"] = "https://github.com/token"
mock_env["ACTIONS_ID_TOKEN_REQUEST_TOKEN"] = "github_token"
mock_oidc_cache.get_cache.return_value = None
mock_handler = MockHTTPHandler(timeout=600.0)
mock_http_handler.return_value = mock_handler
secret_name = "oidc/github/github-audience"
result = get_secret(secret_name)
assert result == "mocked_token", f"Expected token 'mocked_token', got {result}"
assert mock_handler.last_params == {"audience": "github-audience"}
logger.debug(f"set_cache call args: {mock_oidc_cache.set_cache.call_args}")
mock_oidc_cache.set_cache.assert_called_once()
mock_oidc_cache.set_cache.assert_called_with(
key=secret_name, value="mocked_token", ttl=295
)
def test_oidc_github_missing_env():
secret_name = "oidc/github/github-audience"
with pytest.raises(ValueError, match="ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"):
get_secret(secret_name)
def test_oidc_azure_file_success(mock_env, tmp_path):
token_file = tmp_path / "token.txt"
token_file.write_text("azure_token")
mock_env["AZURE_FEDERATED_TOKEN_FILE"] = str(token_file)
secret_name = "oidc/azure/azure-audience"
result = get_secret(secret_name)
assert result == "azure_token"
@patch('litellm.secret_managers.main.get_azure_ad_token_provider')
def test_oidc_azure_ad_token_success(mock_get_azure_ad_token_provider):
mock_token_provider = Mock(return_value="azure_ad_token")
mock_get_azure_ad_token_provider.return_value = mock_token_provider
secret_name = "oidc/azure/api://azure-audience"
result = get_secret(secret_name)
assert result == "azure_ad_token"
mock_get_azure_ad_token_provider.assert_called_once_with(azure_scope="api://azure-audience")
mock_token_provider.assert_called_once_with()
def test_oidc_file_success(tmp_path):
token_file = tmp_path / "token.txt"
token_file.write_text("file_token")
secret_name = f"oidc/file/{token_file}"
result = get_secret(secret_name)
assert result == "file_token"
def test_oidc_env_success(mock_env):
mock_env["CUSTOM_TOKEN"] = "env_token"
secret_name = "oidc/env/CUSTOM_TOKEN"
result = get_secret(secret_name)
assert result == "env_token"
def test_oidc_env_path_success(mock_env, tmp_path):
token_file = tmp_path / "token.txt"
token_file.write_text("env_path_token")
mock_env["TOKEN_PATH"] = str(token_file)
secret_name = "oidc/env_path/TOKEN_PATH"
result = get_secret(secret_name)
assert result == "env_path_token"
def test_unsupported_oidc_provider():
secret_name = "oidc/unsupported/unsupported-audience"
with pytest.raises(ValueError, match="Unsupported OIDC provider"):
get_secret(secret_name)