forked from phoenix/litellm-mirror
* Azure Service Principal with Secret authentication workflow. (#5131) * Implement Azure Service Principal with Secret authentication workflow. * Use `ClientSecretCredential` instead of `DefaultAzureCredential`. * Move imports into the function. * Add type hint for `azure_ad_token_provider`. * Add unit test for router initialization and sample completion using Azure Service Principal with Secret authentication workflow. * Add unit test for router initialization with neither API key nor using Azure Service Principal with Secret authentication workflow. * fix(client_initialization_utils.py): fix typing + overrides * test: fix linting errors * fix(client_initialization_utils.py): fix client init azure ad token logic * fix(router_client_initialization.py): add flag check for reading azure ad token from environment * test(test_streaming.py): skip end of life bedrock model * test(test_router_client_init.py): add correct flag to test --------- Co-authored-by: kzych-inpost <142029278+kzych-inpost@users.noreply.github.com>
This commit is contained in:
parent
2797b30a50
commit
02f288a8a3
4 changed files with 193 additions and 5 deletions
|
@ -116,6 +116,7 @@ ssl_certificate: Optional[str] = None
|
|||
disable_streaming_logging: bool = False
|
||||
in_memory_llm_clients_cache: dict = {}
|
||||
safe_memory_mode: bool = False
|
||||
enable_azure_ad_token_refresh: Optional[bool] = False
|
||||
### DEFAULT AZURE API VERSION ###
|
||||
AZURE_DEFAULT_API_VERSION = "2024-07-01-preview" # this is updated to the latest
|
||||
### COHERE EMBEDDINGS DEFAULT TYPE ###
|
||||
|
|
32
litellm/proxy/secret_managers/get_azure_ad_token_provider.py
Normal file
32
litellm/proxy/secret_managers/get_azure_ad_token_provider.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import os
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def get_azure_ad_token_provider() -> Callable[[], str]:
    """
    Get Azure AD token provider based on Service Principal with Secret workflow.

    The service principal credentials are read from the ``AZURE_CLIENT_ID``,
    ``AZURE_CLIENT_SECRET`` and ``AZURE_TENANT_ID`` environment variables and
    used to build a ``ClientSecretCredential``, which is then wrapped in a
    bearer token provider for the ``https://cognitiveservices.azure.com/.default``
    scope.

    Based on: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py
    See Also:
        https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
        https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.

    Returns:
        Callable that returns a temporary authentication token.

    Raises:
        ValueError: If one of the required environment variables is not set.
            The message names the first missing variable and the original
            ``KeyError`` is chained as the cause.
        ImportError: If the environment is complete but the ``azure.identity``
            package cannot be imported.
    """
    required_variables = (
        "AZURE_CLIENT_ID",
        "AZURE_CLIENT_SECRET",
        "AZURE_TENANT_ID",
    )
    # Validate the configuration before touching the optional Azure SDK, so a
    # deployment that simply has not configured this workflow gets the
    # ValueError that callers handle instead of an unrelated ImportError.
    try:
        client_id, client_secret, tenant_id = (
            os.environ[name] for name in required_variables
        )
    except KeyError as e:
        raise ValueError(
            f"Missing environment variable required by Azure AD workflow: {e.args[0]}"
        ) from e

    # Imported lazily: azure-identity is only needed when this workflow is used.
    from azure.identity import ClientSecretCredential
    from azure.identity import get_bearer_token_provider

    credential = ClientSecretCredential(
        client_id=client_id,
        client_secret=client_secret,
        tenant_id=tenant_id,
    )

    return get_bearer_token_provider(
        credential,
        "https://cognitiveservices.azure.com/.default",
    )
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import os
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
@ -9,6 +9,9 @@ import openai
|
|||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.proxy.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.utils import calculate_max_parallel_requests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -172,7 +175,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
organization_env_name = organization.replace("os.environ/", "")
|
||||
organization = litellm.get_secret(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
azure_ad_token_provider = None
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
if litellm_params.get("tenant_id"):
|
||||
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
|
||||
|
@ -197,6 +200,16 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
if azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
elif (
|
||||
azure_ad_token_provider is None
|
||||
and litellm.enable_azure_ad_token_refresh is True
|
||||
):
|
||||
try:
|
||||
azure_ad_token_provider = get_azure_ad_token_provider()
|
||||
except ValueError:
|
||||
verbose_router_logger.debug(
|
||||
"Azure AD Token Provider could not be used."
|
||||
)
|
||||
if api_version is None:
|
||||
api_version = os.getenv(
|
||||
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
|
||||
|
@ -211,6 +224,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
|
@ -236,6 +250,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
|
@ -258,6 +273,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
|
@ -283,6 +299,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
|
@ -313,6 +330,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
"azure_endpoint": api_base,
|
||||
"api_version": api_version,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
"azure_ad_token_provider": azure_ad_token_provider,
|
||||
}
|
||||
|
||||
if azure_ad_token_provider is not None:
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
#### What this tests ####
|
||||
# This tests client initialization + reinitialization on the router
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
#### What this tests ####
|
||||
# This tests caching on the router
|
||||
import sys, os, time
|
||||
import traceback, asyncio
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.lib.azure import OpenAIError
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm import APIConnectionError, Router
|
||||
|
||||
|
||||
async def test_router_init():
|
||||
|
@ -75,4 +83,133 @@ async def test_router_init():
|
|||
)
|
||||
|
||||
|
||||
# Patch the global flag instead of assigning it, so its previous value is
# restored after the test and does not leak into other tests.
@patch("litellm.enable_azure_ad_token_refresh", True)
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
    mocked_os_lib: MagicMock,
) -> None:
    """
    Test router initialization with neither API key nor using Azure Service Principal with Secret authentication
    workflow (having not provided environment variables).

    Router construction is expected to fail with ``OpenAIError`` because no
    credentials of any kind are available.
    """
    # mock EMPTY environment variables
    environment_variables_expected_to_use: Dict = {}
    mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use)
    # Because of the way mock attributes are stored you can't directly attach a PropertyMock to a mock object.
    # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock
    type(mocked_os_lib).environ = mocked_environ

    # define the model list
    model_list = [
        {
            # test case for Azure Service Principal with Secret authentication
            "model_name": "gpt-4o",
            "litellm_params": {
                # note that there is no api_key here -
                # AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID environment variables should be used instead
                "model": "gpt-4o",
                "base_model": "gpt-4o",
                "api_base": "test_api_base",
                "api_version": "2024-01-01-preview",
                "custom_llm_provider": "azure",
            },
            "model_info": {"mode": "completion"},
        },
    ]

    # initialize the router
    with pytest.raises(OpenAIError):
        # it would raise an error, because environment variables were not provided => azure_ad_token_provider is None
        Router(model_list=model_list)

    # check if the mocked environment variables were reached
    mocked_environ.assert_called()
|
||||
|
||||
|
||||
# Patch the global flag instead of assigning it, so its previous value is
# restored after the test and does not leak into other tests.
@patch("litellm.enable_azure_ad_token_refresh", True)
@patch("azure.identity.get_bearer_token_provider")
@patch("azure.identity.ClientSecretCredential")
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_azure_service_principal_with_secret_with_environment_variables(
    mocked_os_lib: MagicMock,
    mocked_credential: MagicMock,
    mocked_get_bearer_token_provider: MagicMock,
) -> None:
    """
    Test router initialization and sample completion using Azure Service Principal with Secret authentication workflow,
    having provided the (mocked) credentials in environment variables and not provided any API key.

    To allow for local testing without real credentials, first must mock Azure SDK authentication functions
    and environment variables.
    """
    # mock the token provider function
    mocked_func_generating_token = MagicMock(return_value="test_token")
    mocked_get_bearer_token_provider.return_value = mocked_func_generating_token

    # mock the environment variables with mocked credentials
    environment_variables_expected_to_use = {
        "AZURE_CLIENT_ID": "test_client_id",
        "AZURE_CLIENT_SECRET": "test_client_secret",
        "AZURE_TENANT_ID": "test_tenant_id",
    }
    mocked_environ = PropertyMock(return_value=environment_variables_expected_to_use)
    # Because of the way mock attributes are stored you can't directly attach a PropertyMock to a mock object.
    # https://docs.python.org/3.11/library/unittest.mock.html#unittest.mock.PropertyMock
    type(mocked_os_lib).environ = mocked_environ

    # define the model list
    model_list = [
        {
            # test case for Azure Service Principal with Secret authentication
            "model_name": "gpt-4o",
            "litellm_params": {
                # note that there is no api_key here -
                # AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID environment variables should be used instead
                "model": "gpt-4o",
                "base_model": "gpt-4o",
                "api_base": "test_api_base",
                "api_version": "2024-01-01-preview",
                "custom_llm_provider": "azure",
            },
            "model_info": {"mode": "completion"},
        },
    ]

    # initialize the router
    router = Router(model_list=model_list)

    # first check if environment variables were used at all
    mocked_environ.assert_called()
    # then check if the client was initialized with the correct environment variables
    mocked_credential.assert_called_with(
        client_id=environment_variables_expected_to_use["AZURE_CLIENT_ID"],
        client_secret=environment_variables_expected_to_use["AZURE_CLIENT_SECRET"],
        tenant_id=environment_variables_expected_to_use["AZURE_TENANT_ID"],
    )
    # check if the token provider was called at all
    mocked_get_bearer_token_provider.assert_called()
    # then check if the token provider was initialized with the mocked credential
    for call_args in mocked_get_bearer_token_provider.call_args_list:
        assert call_args.args[0] == mocked_credential.return_value
    # however, at this point token should not be fetched yet
    mocked_func_generating_token.assert_not_called()

    # now let's try to make a completion call
    deployment = model_list[0]
    model = deployment["model_name"]
    messages = [
        {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
    ]
    with pytest.raises(APIConnectionError):
        # of course, it will raise an error, because URL is mocked
        router.completion(model=model, messages=messages, temperature=1)  # type: ignore

    # finally verify if the mocked token was used by Azure SDK
    mocked_func_generating_token.assert_called()
|
||||
|
||||
|
||||
# asyncio.run(test_router_init())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue