From 02f288a8a3335e0629ed781ba41f3352952e4ec4 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 2 Sep 2024 14:29:00 -0700 Subject: [PATCH] Azure Service Principal with Secret authentication workflow. (#5131) (#5437) * Azure Service Principal with Secret authentication workflow. (#5131) * Implement Azure Service Principal with Secret authentication workflow. * Use `ClientSecretCredential` instead of `DefaultAzureCredential`. * Move imports into the function. * Add type hint for `azure_ad_token_provider`. * Add unit test for router initialization and sample completion using Azure Service Principal with Secret authentication workflow. * Add unit test for router initialization with neither API key nor using Azure Service Principal with Secret authentication workflow. * fix(client_initializtion_utils.py): fix typing + overrides * test: fix linting errors * fix(client_initialization_utils.py): fix client init azure ad token logic * fix(router_client_initialization.py): add flag check for reading azure ad token from environment * test(test_streaming.py): skip end of life bedrock model * test(test_router_client_init.py): add correct flag to test --------- Co-authored-by: kzych-inpost <142029278+kzych-inpost@users.noreply.github.com> --- litellm/__init__.py | 1 + .../get_azure_ad_token_provider.py | 32 ++++ .../client_initalization_utils.py | 22 ++- litellm/tests/test_router_client_init.py | 143 +++++++++++++++++- 4 files changed, 193 insertions(+), 5 deletions(-) create mode 100644 litellm/proxy/secret_managers/get_azure_ad_token_provider.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 3f22e41b6..f67e2ca83 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -116,6 +116,7 @@ ssl_certificate: Optional[str] = None disable_streaming_logging: bool = False in_memory_llm_clients_cache: dict = {} safe_memory_mode: bool = False +enable_azure_ad_token_refresh: Optional[bool] = False ### DEFAULT AZURE API VERSION ### AZURE_DEFAULT_API_VERSION = "2024-07-01-preview" # this is updated to the latest ### COHERE EMBEDDINGS DEFAULT TYPE ### diff --git a/litellm/proxy/secret_managers/get_azure_ad_token_provider.py b/litellm/proxy/secret_managers/get_azure_ad_token_provider.py new file mode 100644 index 000000000..0ecdae514 --- /dev/null +++ b/litellm/proxy/secret_managers/get_azure_ad_token_provider.py @@ -0,0 +1,32 @@ +import os +from typing import Callable + + +def get_azure_ad_token_provider() -> Callable[[], str]: + """ + Get Azure AD token provider based on Service Principal with Secret workflow. + + Based on: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py + See Also: + https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret; + https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python. + + Returns: + Callable that returns a temporary authentication token. + """ + from azure.identity import ClientSecretCredential + from azure.identity import get_bearer_token_provider + + try: + credential = ClientSecretCredential( + client_id=os.environ["AZURE_CLIENT_ID"], + client_secret=os.environ["AZURE_CLIENT_SECRET"], + tenant_id=os.environ["AZURE_TENANT_ID"], + ) + except KeyError as e: + raise ValueError("Missing environment variable required by Azure AD workflow.") from e + + return get_bearer_token_provider( + credential, + "https://cognitiveservices.azure.com/.default", + ) diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index e98b8b4dd..bd5337b33 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -1,7 +1,7 @@ import asyncio import os import traceback -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional import httpx import openai @@ -9,6 +9,9 @@ import openai import litellm from litellm._logging import verbose_router_logger from litellm.llms.azure import get_azure_ad_token_from_oidc +from litellm.proxy.secret_managers.get_azure_ad_token_provider import ( + get_azure_ad_token_provider, +) from litellm.utils import calculate_max_parallel_requests if TYPE_CHECKING: @@ -172,7 +175,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): organization_env_name = organization.replace("os.environ/", "") organization = litellm.get_secret(organization_env_name) litellm_params["organization"] = organization - azure_ad_token_provider = None + azure_ad_token_provider: Optional[Callable[[], str]] = None if litellm_params.get("tenant_id"): verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth") azure_ad_token_provider = get_azure_ad_token_from_entrata_id( @@ -197,6 +200,16 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): if azure_ad_token is not None: if azure_ad_token.startswith("oidc/"): azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + elif ( + azure_ad_token_provider is None + and litellm.enable_azure_ad_token_refresh is True + ): + try: + azure_ad_token_provider = get_azure_ad_token_provider() + except ValueError: + verbose_router_logger.debug( + "Azure AD Token Provider could not be used." + ) if api_version is None: api_version = os.getenv( "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION @@ -211,6 +224,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.AsyncAzureOpenAI( api_key=api_key, azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, timeout=timeout, @@ -236,6 +250,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.AzureOpenAI( # type: ignore api_key=api_key, azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, timeout=timeout, @@ -258,6 +273,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.AsyncAzureOpenAI( # type: ignore api_key=api_key, azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, timeout=stream_timeout, @@ -283,6 +299,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.AzureOpenAI( # type: ignore api_key=api_key, azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, timeout=stream_timeout, @@ -313,6 +330,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): "azure_endpoint": api_base, "api_version": api_version, "azure_ad_token": azure_ad_token, + "azure_ad_token_provider": azure_ad_token_provider, } if azure_ad_token_provider is not None: diff --git a/litellm/tests/test_router_client_init.py b/litellm/tests/test_router_client_init.py index 79f8ba8b2..0984e406d 100644 --- a/litellm/tests/test_router_client_init.py +++ b/litellm/tests/test_router_client_init.py @@ -1,17 +1,25 @@ #### What this tests #### # This tests client initialization + reinitialization on the router +import asyncio +import os + #### What this tests #### # This tests caching on the router -import sys, os, time -import traceback, asyncio +import sys +import time +import traceback +from typing import Dict +from unittest.mock import MagicMock, PropertyMock, patch + import pytest +from openai.lib.azure import OpenAIError sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import litellm -from litellm import Router +from litellm import APIConnectionError, Router async def test_router_init(): @@ -75,4 +83,133 @@ async def test_router_init(): ) +@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os") +def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret( + mocked_os_lib: MagicMock, +) -> None: + """ + Test router initialization with neither API key nor using Azure Service Principal with Secret authentication + workflow (having not provided environment variables). + """ + litellm.enable_azure_ad_token_refresh = True + # mock EMPTY environment variables + environment_variables_expected_to_use: Dict = {} + mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use) + # Because of the way mock attributes are stored you can’t directly attach a PropertyMock to a mock object. + # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock + type(mocked_os_lib).environ = mocked_environ + + # define the model list + model_list = [ + { + # test case for Azure Service Principal with Secret authentication + "model_name": "gpt-4o", + "litellm_params": { + # checkout there is no api_key here - + # AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID environment variables should be used instead + "model": "gpt-4o", + "base_model": "gpt-4o", + "api_base": "test_api_base", + "api_version": "2024-01-01-preview", + "custom_llm_provider": "azure", + }, + "model_info": {"mode": "completion"}, + }, + ] + + # initialize the router + with pytest.raises(OpenAIError): + # it would raise an error, because environment variables were not provided => azure_ad_token_provider is None + Router(model_list=model_list) + + # check if the mocked environment variables were reached + mocked_environ.assert_called() + + +@patch("azure.identity.get_bearer_token_provider") +@patch("azure.identity.ClientSecretCredential") +@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os") +def test_router_init_azure_service_principal_with_secret_with_environment_variables( + mocked_os_lib: MagicMock, + mocked_credential: MagicMock, + mocked_get_bearer_token_provider: MagicMock, +) -> None: + """ + Test router initialization and sample completion using Azure Service Principal with Secret authentication workflow, + having provided the (mocked) credentials in environment variables and not provided any API key. + + To allow for local testing without real credentials, first must mock Azure SDK authentication functions + and environment variables. + """ + litellm.enable_azure_ad_token_refresh = True + # mock the token provider function + mocked_func_generating_token = MagicMock(return_value="test_token") + mocked_get_bearer_token_provider.return_value = mocked_func_generating_token + + # mock the environment variables with mocked credentials + environment_variables_expected_to_use = { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_CLIENT_SECRET": "test_client_secret", + "AZURE_TENANT_ID": "test_tenant_id", + } + mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use) + # Because of the way mock attributes are stored you can’t directly attach a PropertyMock to a mock object. + # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock + type(mocked_os_lib).environ = mocked_environ + + # define the model list + model_list = [ + { + # test case for Azure Service Principal with Secret authentication + "model_name": "gpt-4o", + "litellm_params": { + # checkout there is no api_key here - + # AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID environment variables should be used instead + "model": "gpt-4o", + "base_model": "gpt-4o", + "api_base": "test_api_base", + "api_version": "2024-01-01-preview", + "custom_llm_provider": "azure", + }, + "model_info": {"mode": "completion"}, + }, + ] + + # initialize the router + router = Router(model_list=model_list) + + # first check if environment variables were used at all + mocked_environ.assert_called() + # then check if the client was initialized with the correct environment variables + mocked_credential.assert_called_with( + **{ + "client_id": environment_variables_expected_to_use["AZURE_CLIENT_ID"], + "client_secret": environment_variables_expected_to_use[ + "AZURE_CLIENT_SECRET" + ], + "tenant_id": environment_variables_expected_to_use["AZURE_TENANT_ID"], + } + ) + # check if the token provider was called at all + mocked_get_bearer_token_provider.assert_called() + # then check if the token provider was initialized with the mocked credential + for call_args in mocked_get_bearer_token_provider.call_args_list: + assert call_args.args[0] == mocked_credential.return_value + # however, at this point token should not be fetched yet + mocked_func_generating_token.assert_not_called() + + # now let's try to make a completion call + deployment = model_list[0] + model = deployment["model_name"] + messages = [ + {"role": "user", "content": f"write a one sentence poem {time.time()}?"} + ] + with pytest.raises(APIConnectionError): + # of course, it will raise an error, because URL is mocked + router.completion(model=model, messages=messages, temperature=1) # type: ignore + + # finally verify if the mocked token was used by Azure SDK + mocked_func_generating_token.assert_called() + + # asyncio.run(test_router_init())