mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge 4471b45e0c
into b82af5b826
This commit is contained in:
commit
2ee258df3d
4 changed files with 222 additions and 37 deletions
|
@ -19,6 +19,7 @@ LiteLLM supports the following OIDC identity providers:
|
|||
| CircleCI v2 | `circleci_v2`| No |
|
||||
| GitHub Actions | `github` | Yes |
|
||||
| Azure Kubernetes Service | `azure` | No |
|
||||
| Azure AD | `azure` | Yes |
|
||||
| File | `file` | No |
|
||||
| Environment Variable | `env` | No |
|
||||
| Environment Path | `env_path` | No |
|
||||
|
@ -261,3 +262,15 @@ The custom role below is the recommended minimum permissions for the Azure appli
|
|||
_Note: Your UUIDs will be different._
|
||||
|
||||
Please contact us for paid enterprise support if you need help setting up Azure AD applications.
|
||||
|
||||
### Azure AD -> Amazon Bedrock
|
||||
```yaml
|
||||
model list:
|
||||
- model_name: aws/claude-3-5-sonnet
|
||||
litellm_params:
|
||||
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
|
||||
aws_region_name: "eu-central-1"
|
||||
aws_role_name: "arn:aws:iam::12345678:role/bedrock-role"
|
||||
aws_web_identity_token: "oidc/azure/api://123-456-789-9d04"
|
||||
aws_session_name: "litellm-session"
|
||||
```
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def get_azure_ad_token_provider() -> Callable[[], str]:
|
||||
def get_azure_ad_token_provider(azure_scope: Optional[str] = None) -> Callable[[], str]:
|
||||
"""
|
||||
Get Azure AD token provider based on Service Principal with Secret workflow.
|
||||
|
||||
|
@ -11,15 +11,22 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
|
|||
https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
|
||||
https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.
|
||||
|
||||
Args:
|
||||
azure_scope (str, optional): The Azure scope to request token for.
|
||||
Defaults to environment variable AZURE_SCOPE or
|
||||
"https://cognitiveservices.azure.com/.default".
|
||||
|
||||
Returns:
|
||||
Callable that returns a temporary authentication token.
|
||||
"""
|
||||
import azure.identity as identity
|
||||
from azure.identity import get_bearer_token_provider
|
||||
|
||||
azure_scope = os.environ.get(
|
||||
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
if azure_scope is None:
|
||||
azure_scope = os.environ.get(
|
||||
"AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")
|
||||
|
||||
cred_cls = getattr(identity, cred)
|
||||
|
|
|
@ -12,6 +12,9 @@ from litellm._logging import print_verbose, verbose_logger
|
|||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
|
@ -102,6 +105,7 @@ def get_secret( # noqa: PLR0915
|
|||
if secret_name.startswith("oidc/"):
|
||||
secret_name_split = secret_name.replace("oidc/", "")
|
||||
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
|
||||
oidc_aud = "/".join(secret_name_split.split("/")[1:])
|
||||
# TODO: Add caching for HTTP requests
|
||||
if oidc_provider == "google":
|
||||
oidc_token = oidc_cache.get_cache(key=secret_name)
|
||||
|
@ -137,10 +141,7 @@ def get_secret( # noqa: PLR0915
|
|||
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
|
||||
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
|
||||
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
||||
if (
|
||||
actions_id_token_request_url is None
|
||||
or actions_id_token_request_token is None
|
||||
):
|
||||
if actions_id_token_request_url is None or actions_id_token_request_token is None:
|
||||
raise ValueError(
|
||||
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
|
||||
)
|
||||
|
@ -168,7 +169,19 @@ def get_secret( # noqa: PLR0915
|
|||
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
|
||||
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
|
||||
if azure_federated_token_file is None:
|
||||
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
|
||||
verbose_logger.warning(
|
||||
"AZURE_FEDERATED_TOKEN_FILE not found in environment will use Azure AD token provider"
|
||||
)
|
||||
azure_token_provider = get_azure_ad_token_provider(azure_scope=oidc_aud)
|
||||
try:
|
||||
oidc_token = azure_token_provider()
|
||||
if oidc_token is None:
|
||||
raise ValueError("Azure OIDC provider returned None token")
|
||||
return oidc_token
|
||||
except Exception as e:
|
||||
error_msg = f"Azure OIDC provider failed: {str(e)}"
|
||||
verbose_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
with open(azure_federated_token_file, "r") as f:
|
||||
oidc_token = f.read()
|
||||
return oidc_token
|
||||
|
@ -195,10 +208,7 @@ def get_secret( # noqa: PLR0915
|
|||
raise ValueError("Unsupported OIDC provider")
|
||||
|
||||
try:
|
||||
if (
|
||||
_should_read_secret_from_secret_manager()
|
||||
and litellm.secret_manager_client is not None
|
||||
):
|
||||
if _should_read_secret_from_secret_manager() and litellm.secret_manager_client is not None:
|
||||
try:
|
||||
client = litellm.secret_manager_client
|
||||
key_manager = "local"
|
||||
|
@ -224,9 +234,7 @@ def get_secret( # noqa: PLR0915
|
|||
):
|
||||
encrypted_secret: Any = os.getenv(secret_name)
|
||||
if encrypted_secret is None:
|
||||
raise ValueError(
|
||||
"Google KMS requires the encrypted secret to be in the environment!"
|
||||
)
|
||||
raise ValueError("Google KMS requires the encrypted secret to be in the environment!")
|
||||
b64_flag = _is_base64(encrypted_secret)
|
||||
if b64_flag is True: # if passed in as encoded b64 string
|
||||
encrypted_secret = base64.b64decode(encrypted_secret)
|
||||
|
@ -241,20 +249,14 @@ def get_secret( # noqa: PLR0915
|
|||
"ciphertext": ciphertext,
|
||||
}
|
||||
)
|
||||
secret = response.plaintext.decode(
|
||||
"utf-8"
|
||||
) # assumes the original value was encoded with utf-8
|
||||
secret = response.plaintext.decode("utf-8") # assumes the original value was encoded with utf-8
|
||||
elif key_manager == KeyManagementSystem.AWS_KMS.value:
|
||||
"""
|
||||
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
|
||||
"""
|
||||
encrypted_value = os.getenv(secret_name, None)
|
||||
if encrypted_value is None:
|
||||
raise Exception(
|
||||
"AWS KMS - Encrypted Value of Key={} is None".format(
|
||||
secret_name
|
||||
)
|
||||
)
|
||||
raise Exception("AWS KMS - Encrypted Value of Key={} is None".format(secret_name))
|
||||
# Decode the base64 encoded ciphertext
|
||||
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||
|
||||
|
@ -281,14 +283,10 @@ def get_secret( # noqa: PLR0915
|
|||
print_verbose(f"get_secret_value_response: {secret}")
|
||||
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
|
||||
try:
|
||||
secret = client.get_secret_from_google_secret_manager(
|
||||
secret_name
|
||||
)
|
||||
secret = client.get_secret_from_google_secret_manager(secret_name)
|
||||
print_verbose(f"secret from google secret manager: {secret}")
|
||||
if secret is None:
|
||||
raise ValueError(
|
||||
f"No secret found in Google Secret Manager for {secret_name}"
|
||||
)
|
||||
raise ValueError(f"No secret found in Google Secret Manager for {secret_name}")
|
||||
except Exception as e:
|
||||
print_verbose(f"An error occurred - {str(e)}")
|
||||
raise e
|
||||
|
@ -296,9 +294,7 @@ def get_secret( # noqa: PLR0915
|
|||
try:
|
||||
secret = client.sync_read_secret(secret_name=secret_name)
|
||||
if secret is None:
|
||||
raise ValueError(
|
||||
f"No secret found in Hashicorp Secret Manager for {secret_name}"
|
||||
)
|
||||
raise ValueError(f"No secret found in Hashicorp Secret Manager for {secret_name}")
|
||||
except Exception as e:
|
||||
print_verbose(f"An error occurred - {str(e)}")
|
||||
raise e
|
||||
|
@ -323,9 +319,7 @@ def get_secret( # noqa: PLR0915
|
|||
else:
|
||||
secret = os.environ.get(secret_name)
|
||||
secret_value_as_bool = str_to_bool(secret) if secret is not None else None
|
||||
if secret_value_as_bool is not None and isinstance(
|
||||
secret_value_as_bool, bool
|
||||
):
|
||||
if secret_value_as_bool is not None and isinstance(secret_value_as_bool, bool):
|
||||
return secret_value_as_bool
|
||||
else:
|
||||
return secret
|
||||
|
|
171
tests/litellm/secrets_managers/test_main_oidc.py
Normal file
171
tests/litellm/secrets_managers/test_main_oidc.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
import pytest
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
from litellm.secret_managers.main import get_secret
|
||||
import logging
|
||||
|
||||
# Set up logging for debugging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mock HTTPHandler and oidc_cache
|
||||
class MockHTTPHandler:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
self.status_code = 200
|
||||
self.text = "mocked_token"
|
||||
self.json_data = {"value": "mocked_token"}
|
||||
|
||||
def get(self, url, params=None, headers=None):
|
||||
# Store params for audience verification
|
||||
self.last_params = params
|
||||
logger.debug(f"MockHTTPHandler.get called with url={url}, params={params}, headers={headers}")
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = self.status_code
|
||||
mock_response.text = self.text
|
||||
mock_response.json.return_value = self.json_data
|
||||
return mock_response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oidc_cache():
|
||||
cache = Mock()
|
||||
cache.get_cache.return_value = None
|
||||
cache.set_cache = Mock()
|
||||
return cache
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
yield os.environ
|
||||
|
||||
@patch('litellm.secret_managers.main.oidc_cache')
|
||||
@patch('litellm.secret_managers.main.HTTPHandler')
|
||||
def test_oidc_google_success(mock_http_handler, mock_oidc_cache):
|
||||
mock_oidc_cache.get_cache.return_value = None
|
||||
mock_handler = MockHTTPHandler(timeout=600.0)
|
||||
mock_http_handler.return_value = mock_handler
|
||||
secret_name = "oidc/google/[invalid url, do not cite]"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "mocked_token"
|
||||
assert mock_handler.last_params == {"audience": "[invalid url, do not cite]"}
|
||||
mock_oidc_cache.set_cache.assert_called_once_with(
|
||||
key=secret_name, value="mocked_token", ttl=3540
|
||||
)
|
||||
|
||||
@patch('litellm.secret_managers.main.oidc_cache')
|
||||
def test_oidc_google_cached(mock_oidc_cache):
|
||||
mock_oidc_cache.get_cache.return_value = "cached_token"
|
||||
|
||||
secret_name = "oidc/google/[invalid url, do not cite]"
|
||||
with patch('litellm.HTTPHandler') as mock_http:
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "cached_token", f"Expected cached token, got {result}"
|
||||
mock_oidc_cache.get_cache.assert_called_with(key=secret_name)
|
||||
mock_http.assert_not_called()
|
||||
|
||||
def test_oidc_google_failure(mock_oidc_cache):
|
||||
mock_handler = MockHTTPHandler(timeout=600.0)
|
||||
mock_handler.status_code = 400
|
||||
|
||||
with patch('litellm.secret_managers.main.HTTPHandler', return_value=mock_handler):
|
||||
mock_oidc_cache.get_cache.return_value = None
|
||||
secret_name = "oidc/google/https://example.com/api"
|
||||
|
||||
with pytest.raises(ValueError, match="Google OIDC provider failed"):
|
||||
get_secret(secret_name)
|
||||
|
||||
def test_oidc_circleci_success(mock_env):
|
||||
mock_env["CIRCLE_OIDC_TOKEN"] = "circleci_token"
|
||||
|
||||
secret_name = "oidc/circleci/test-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "circleci_token"
|
||||
|
||||
def test_oidc_circleci_failure():
|
||||
secret_name = "oidc/circleci/test-audience"
|
||||
|
||||
with pytest.raises(ValueError, match="CIRCLE_OIDC_TOKEN not found in environment"):
|
||||
get_secret(secret_name)
|
||||
|
||||
@patch('litellm.secret_managers.main.oidc_cache')
|
||||
@patch('litellm.secret_managers.main.HTTPHandler')
|
||||
def test_oidc_github_success(mock_http_handler, mock_oidc_cache, mock_env):
|
||||
mock_env["ACTIONS_ID_TOKEN_REQUEST_URL"] = "https://github.com/token"
|
||||
mock_env["ACTIONS_ID_TOKEN_REQUEST_TOKEN"] = "github_token"
|
||||
mock_oidc_cache.get_cache.return_value = None
|
||||
mock_handler = MockHTTPHandler(timeout=600.0)
|
||||
mock_http_handler.return_value = mock_handler
|
||||
|
||||
secret_name = "oidc/github/github-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "mocked_token", f"Expected token 'mocked_token', got {result}"
|
||||
assert mock_handler.last_params == {"audience": "github-audience"}
|
||||
logger.debug(f"set_cache call args: {mock_oidc_cache.set_cache.call_args}")
|
||||
mock_oidc_cache.set_cache.assert_called_once()
|
||||
mock_oidc_cache.set_cache.assert_called_with(
|
||||
key=secret_name, value="mocked_token", ttl=295
|
||||
)
|
||||
|
||||
def test_oidc_github_missing_env():
|
||||
secret_name = "oidc/github/github-audience"
|
||||
|
||||
with pytest.raises(ValueError, match="ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"):
|
||||
get_secret(secret_name)
|
||||
|
||||
def test_oidc_azure_file_success(mock_env, tmp_path):
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("azure_token")
|
||||
mock_env["AZURE_FEDERATED_TOKEN_FILE"] = str(token_file)
|
||||
|
||||
secret_name = "oidc/azure/azure-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "azure_token"
|
||||
|
||||
@patch('litellm.secret_managers.main.get_azure_ad_token_provider')
|
||||
def test_oidc_azure_ad_token_success(mock_get_azure_ad_token_provider):
|
||||
mock_token_provider = Mock(return_value="azure_ad_token")
|
||||
mock_get_azure_ad_token_provider.return_value = mock_token_provider
|
||||
secret_name = "oidc/azure/api://azure-audience"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "azure_ad_token"
|
||||
mock_get_azure_ad_token_provider.assert_called_once_with(azure_scope="api://azure-audience")
|
||||
mock_token_provider.assert_called_once_with()
|
||||
|
||||
def test_oidc_file_success(tmp_path):
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("file_token")
|
||||
|
||||
secret_name = f"oidc/file/{token_file}"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "file_token"
|
||||
|
||||
def test_oidc_env_success(mock_env):
|
||||
mock_env["CUSTOM_TOKEN"] = "env_token"
|
||||
|
||||
secret_name = "oidc/env/CUSTOM_TOKEN"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "env_token"
|
||||
|
||||
def test_oidc_env_path_success(mock_env, tmp_path):
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("env_path_token")
|
||||
mock_env["TOKEN_PATH"] = str(token_file)
|
||||
|
||||
secret_name = "oidc/env_path/TOKEN_PATH"
|
||||
result = get_secret(secret_name)
|
||||
|
||||
assert result == "env_path_token"
|
||||
|
||||
def test_unsupported_oidc_provider():
|
||||
secret_name = "oidc/unsupported/unsupported-audience"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported OIDC provider"):
|
||||
get_secret(secret_name)
|
Loading…
Add table
Add a link
Reference in a new issue