Azure Service Principal with Secret authentication workflow. (#5131) (#5437)

* Azure Service Principal with Secret authentication workflow. (#5131)

* Implement Azure Service Principal with Secret authentication workflow.

* Use `ClientSecretCredential` instead of `DefaultAzureCredential`.

* Move imports into the function.

* Add type hint for `azure_ad_token_provider`.

* Add unit test for router initialization and sample completion using Azure Service Principal with Secret authentication workflow.

* Add unit test for router initialization with neither an API key nor the Azure Service Principal with Secret authentication workflow.

* fix(client_initialization_utils.py): fix typing + overrides

* test: fix linting errors

* fix(client_initialization_utils.py): fix client init azure ad token logic

* fix(router_client_initialization.py): add flag check for reading azure ad token from environment

* test(test_streaming.py): skip end of life bedrock model

* test(test_router_client_init.py): add correct flag to test

---------

Co-authored-by: kzych-inpost <142029278+kzych-inpost@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2024-09-02 14:29:00 -07:00 committed by GitHub
parent 2797b30a50
commit 02f288a8a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 193 additions and 5 deletions

View file

@ -116,6 +116,7 @@ ssl_certificate: Optional[str] = None
disable_streaming_logging: bool = False
in_memory_llm_clients_cache: dict = {}
safe_memory_mode: bool = False
enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###
AZURE_DEFAULT_API_VERSION = "2024-07-01-preview" # this is updated to the latest
### COHERE EMBEDDINGS DEFAULT TYPE ###

View file

@ -0,0 +1,32 @@
import os
from typing import Callable
def get_azure_ad_token_provider() -> Callable[[], str]:
    """
    Get Azure AD token provider based on Service Principal with Secret workflow.

    Reads ``AZURE_CLIENT_ID``, ``AZURE_CLIENT_SECRET`` and ``AZURE_TENANT_ID``
    from ``os.environ``, builds a ``ClientSecretCredential`` from them and wraps
    it in a bearer-token provider scoped to
    ``https://cognitiveservices.azure.com/.default``.

    Based on: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py

    See Also:
        https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
        https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.

    Returns:
        Callable that returns a temporary authentication token.

    Raises:
        ValueError: If any of the three environment variables is not set. The
            message names the first missing variable.
    """
    # Imported at call time rather than at module import time
    # (presumably so azure-identity stays an optional dependency - confirm).
    from azure.identity import ClientSecretCredential
    from azure.identity import get_bearer_token_provider

    # Only the environment lookups live inside the ``try`` so that a KeyError
    # raised from anywhere else is never misreported as a missing variable.
    try:
        client_id = os.environ["AZURE_CLIENT_ID"]
        client_secret = os.environ["AZURE_CLIENT_SECRET"]
        tenant_id = os.environ["AZURE_TENANT_ID"]
    except KeyError as e:
        raise ValueError(
            f"Missing environment variable required by Azure AD workflow: {e.args[0]}"
        ) from e

    credential = ClientSecretCredential(
        client_id=client_id,
        client_secret=client_secret,
        tenant_id=tenant_id,
    )

    return get_bearer_token_provider(
        credential,
        "https://cognitiveservices.azure.com/.default",
    )

View file

@ -1,7 +1,7 @@
import asyncio
import os
import traceback
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional
import httpx
import openai
@ -9,6 +9,9 @@ import openai
import litellm
from litellm._logging import verbose_router_logger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.proxy.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.utils import calculate_max_parallel_requests
if TYPE_CHECKING:
@ -172,7 +175,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
azure_ad_token_provider = None
azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"):
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
@ -197,6 +200,16 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
elif (
azure_ad_token_provider is None
and litellm.enable_azure_ad_token_refresh is True
):
try:
azure_ad_token_provider = get_azure_ad_token_provider()
except ValueError:
verbose_router_logger.debug(
"Azure AD Token Provider could not be used."
)
if api_version is None:
api_version = os.getenv(
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
@ -211,6 +224,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@ -236,6 +250,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@ -258,6 +273,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
@ -283,6 +299,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
@ -313,6 +330,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider,
}
if azure_ad_token_provider is not None:

View file

@ -1,17 +1,25 @@
#### What this tests ####
# This tests client initialization + reinitialization on the router
import asyncio
import os
#### What this tests ####
# This tests caching on the router
import sys, os, time
import traceback, asyncio
import sys
import time
import traceback
from typing import Dict
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from openai.lib.azure import OpenAIError
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm import APIConnectionError, Router
async def test_router_init():
@ -75,4 +83,133 @@ async def test_router_init():
)
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
    mocked_os_lib: MagicMock,
) -> None:
    """
    Test router initialization with neither API key nor using Azure Service Principal with Secret authentication
    workflow (having not provided environment variables).
    """
    # mock EMPTY environment variables
    environment_variables_expected_to_use: Dict = {}
    mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use)
    # Because of the way mock attributes are stored you can't directly attach a PropertyMock to a mock object.
    # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock
    type(mocked_os_lib).environ = mocked_environ

    # define the model list
    model_list = [
        {
            # deployment with no usable credentials at all
            "model_name": "gpt-4o",
            "litellm_params": {
                # note there is no api_key here, and AZURE_CLIENT_ID, AZURE_CLIENT_SECRET
                # and AZURE_TENANT_ID are absent from the (mocked) environment as well
                "model": "gpt-4o",
                "base_model": "gpt-4o",
                "api_base": "test_api_base",
                "api_version": "2024-01-01-preview",
                "custom_llm_provider": "azure",
            },
            "model_info": {"mode": "completion"},
        },
    ]

    # Enable the flag only for the duration of this test: patch.object restores the
    # previous value on exit, so the module-level setting does not leak into other tests.
    with patch.object(litellm, "enable_azure_ad_token_refresh", True):
        # initialize the router
        with pytest.raises(OpenAIError):
            # it would raise an error, because environment variables were not provided => azure_ad_token_provider is None
            Router(model_list=model_list)

        # check if the mocked environment variables were reached
        mocked_environ.assert_called()
@patch("azure.identity.get_bearer_token_provider")
@patch("azure.identity.ClientSecretCredential")
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_azure_service_principal_with_secret_with_environment_variables(
    mocked_os_lib: MagicMock,
    mocked_credential: MagicMock,
    mocked_get_bearer_token_provider: MagicMock,
) -> None:
    """
    Test router initialization and sample completion using Azure Service Principal with Secret authentication workflow,
    having provided the (mocked) credentials in environment variables and not provided any API key.

    To allow for local testing without real credentials, first must mock Azure SDK authentication functions
    and environment variables.
    """
    # mock the token provider function
    mocked_func_generating_token = MagicMock(return_value="test_token")
    mocked_get_bearer_token_provider.return_value = mocked_func_generating_token

    # mock the environment variables with mocked credentials
    environment_variables_expected_to_use = {
        "AZURE_CLIENT_ID": "test_client_id",
        "AZURE_CLIENT_SECRET": "test_client_secret",
        "AZURE_TENANT_ID": "test_tenant_id",
    }
    mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use)
    # Because of the way mock attributes are stored you can't directly attach a PropertyMock to a mock object.
    # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock
    type(mocked_os_lib).environ = mocked_environ

    # define the model list
    model_list = [
        {
            # test case for Azure Service Principal with Secret authentication
            "model_name": "gpt-4o",
            "litellm_params": {
                # note there is no api_key here -
                # AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID environment variables should be used instead
                "model": "gpt-4o",
                "base_model": "gpt-4o",
                "api_base": "test_api_base",
                "api_version": "2024-01-01-preview",
                "custom_llm_provider": "azure",
            },
            "model_info": {"mode": "completion"},
        },
    ]

    # Enable the flag only for the duration of this test: patch.object restores the
    # previous value on exit, so the module-level setting does not leak into other tests.
    with patch.object(litellm, "enable_azure_ad_token_refresh", True):
        # initialize the router
        router = Router(model_list=model_list)

        # first check if environment variables were used at all
        mocked_environ.assert_called()
        # then check if the client was initialized with the correct environment variables
        mocked_credential.assert_called_with(
            **{
                "client_id": environment_variables_expected_to_use["AZURE_CLIENT_ID"],
                "client_secret": environment_variables_expected_to_use[
                    "AZURE_CLIENT_SECRET"
                ],
                "tenant_id": environment_variables_expected_to_use["AZURE_TENANT_ID"],
            }
        )
        # check if the token provider was called at all
        mocked_get_bearer_token_provider.assert_called()
        # then check if the token provider was initialized with the mocked credential
        for call_args in mocked_get_bearer_token_provider.call_args_list:
            assert call_args.args[0] == mocked_credential.return_value
        # however, at this point token should not be fetched yet
        mocked_func_generating_token.assert_not_called()

        # now let's try to make a completion call
        deployment = model_list[0]
        model = deployment["model_name"]
        messages = [
            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
        ]
        with pytest.raises(APIConnectionError):
            # of course, it will raise an error, because URL is mocked
            router.completion(model=model, messages=messages, temperature=1)  # type: ignore

        # finally verify if the mocked token was used by Azure SDK
        mocked_func_generating_token.assert_called()
# asyncio.run(test_router_init())