diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 877c9f2bc2..4575942d58 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -31,6 +31,7 @@ from ...types.llms.openai import HttpxBinaryResponseContent from ..base import BaseLLM from .common_utils import ( AzureOpenAIError, + BaseAzureLLM, get_azure_ad_token_from_oidc, process_azure_headers, select_azure_base_url_or_endpoint, @@ -120,7 +121,7 @@ def _check_dynamic_azure_params( return False -class AzureChatCompletion(BaseLLM): +class AzureChatCompletion(BaseAzureLLM, BaseLLM): def __init__(self) -> None: super().__init__() @@ -348,6 +349,7 @@ class AzureChatCompletion(BaseLLM): logging_obj=logging_obj, max_retries=max_retries, convert_tool_call_to_json_mode=json_mode, + litellm_params=litellm_params, ) elif "stream" in optional_params and optional_params["stream"] is True: return self.streaming( @@ -476,29 +478,18 @@ class AzureChatCompletion(BaseLLM): azure_ad_token_provider: Optional[Callable] = None, convert_tool_call_to_json_mode: Optional[bool] = None, client=None, # this is the AsyncAzureOpenAI + litellm_params: Optional[dict] = None, ): response = None try: # init AzureOpenAI Client - azure_client_params = { - "api_version": api_version, - "azure_endpoint": api_base, - "azure_deployment": model, - "http_client": litellm.aclient_session, - "max_retries": max_retries, - "timeout": timeout, - } - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params=azure_client_params + azure_client_params = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + model_name=model, + api_version=api_version, ) - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - azure_client_params["azure_ad_token"] = azure_ad_token - elif 
class BaseAzureLLM:
    """Mixin that builds the keyword arguments for an Azure OpenAI SDK client."""

    def initialize_azure_sdk_client(
        self,
        litellm_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        model_name: str,
        api_version: Optional[str],
    ) -> dict:
        """Assemble the parameter dict used to construct an Azure OpenAI client.

        Args:
            litellm_params: Call parameters. The keys ``azure_ad_token``,
                ``tenant_id``, ``client_id``, ``client_secret``,
                ``azure_username`` and ``azure_password`` are read from it.
            api_key: Azure API key, or ``None`` to rely on Azure AD auth.
            api_base: Azure endpoint, stored under ``azure_endpoint``.
            model_name: Only used in the debug log message.
            api_version: API version; when ``None`` it falls back to the
                ``AZURE_API_VERSION`` environment variable and then to
                ``litellm.AZURE_DEFAULT_API_VERSION``.

        Returns:
            The dict with ``api_key``, ``azure_endpoint``, ``api_version``,
            ``azure_ad_token`` and ``azure_ad_token_provider`` after it has
            been passed through ``select_azure_base_url_or_endpoint``.
        """
        azure_ad_token_provider: Optional[Callable[[], str]] = None
        azure_ad_token = litellm_params.get("azure_ad_token")
        tenant_id = litellm_params.get("tenant_id")
        client_id = litellm_params.get("client_id")
        client_secret = litellm_params.get("client_secret")
        azure_username = litellm_params.get("azure_username")
        azure_password = litellm_params.get("azure_password")

        # An explicit api_key has priority over service-principal credentials.
        if not api_key and tenant_id and client_id and client_secret:
            verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
            azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
                tenant_id=tenant_id,
                client_id=client_id,
                client_secret=client_secret,
            )

        # NOTE(review): this branch is not gated on api_key and replaces any
        # provider chosen above -- confirm that this precedence is intended.
        if azure_username and azure_password and client_id:
            azure_ad_token_provider = get_azure_ad_token_from_username_password(
                azure_username=azure_username,
                azure_password=azure_password,
                client_id=client_id,
            )

        if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
            # "oidc/..." values are references that must be exchanged for a token.
            azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
        elif (
            not api_key
            and azure_ad_token_provider is None
            and litellm.enable_azure_ad_token_refresh is True
        ):
            # Best effort: fall back to the default token provider when no
            # other credential was supplied and refresh is enabled.
            try:
                azure_ad_token_provider = get_azure_ad_token_provider()
            except ValueError:
                verbose_logger.debug("Azure AD Token Provider could not be used.")

        if api_version is None:
            api_version = os.getenv(
                "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
            )

        # Only the first 8 chars of the api_key are written to the debug log.
        _api_key = api_key
        if isinstance(_api_key, str):
            _api_key = _api_key[:8] + "*" * 15
        verbose_logger.debug(
            f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
        )

        azure_client_params = {
            "api_key": api_key,
            "azure_endpoint": api_base,
            "api_version": api_version,
            "azure_ad_token": azure_ad_token,
            "azure_ad_token_provider": azure_ad_token_provider,
        }

        # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
        # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
        return select_azure_base_url_or_endpoint(azure_client_params)
-from litellm.llms.azure.common_utils import initialize_azure_sdk_client +from litellm.llms.azure.common_utils import BaseAzureLLM +from litellm.types.utils import CallTypes # Mock the necessary dependencies @@ -58,7 +59,7 @@ def setup_mocks(): def test_initialize_with_api_key(setup_mocks): # Test with api_key provided - result = initialize_azure_sdk_client( + result = BaseAzureLLM().initialize_azure_sdk_client( litellm_params={}, api_key="test-api-key", api_base="https://test.openai.azure.com", @@ -76,7 +77,7 @@ def test_initialize_with_api_key(setup_mocks): def test_initialize_with_tenant_credentials(setup_mocks): # Test with tenant_id, client_id, and client_secret provided - result = initialize_azure_sdk_client( + result = BaseAzureLLM().initialize_azure_sdk_client( litellm_params={ "tenant_id": "test-tenant-id", "client_id": "test-client-id", @@ -103,7 +104,7 @@ def test_initialize_with_tenant_credentials(setup_mocks): def test_initialize_with_username_password(setup_mocks): # Test with azure_username, azure_password, and client_id provided - result = initialize_azure_sdk_client( + result = BaseAzureLLM().initialize_azure_sdk_client( litellm_params={ "azure_username": "test-username", "azure_password": "test-password", @@ -128,7 +129,7 @@ def test_initialize_with_username_password(setup_mocks): def test_initialize_with_oidc_token(setup_mocks): # Test with azure_ad_token that starts with "oidc/" - result = initialize_azure_sdk_client( + result = BaseAzureLLM().initialize_azure_sdk_client( litellm_params={"azure_ad_token": "oidc/test-token"}, api_key=None, api_base="https://test.openai.azure.com", @@ -148,7 +149,7 @@ def test_initialize_with_enable_token_refresh(setup_mocks): setup_mocks["litellm"].enable_azure_ad_token_refresh = True # Test with token refresh enabled - result = initialize_azure_sdk_client( + result = BaseAzureLLM().initialize_azure_sdk_client( litellm_params={}, api_key=None, api_base="https://test.openai.azure.com", @@ -169,7 +170,7 @@ def 
@pytest.mark.parametrize(
    "call_type",
    [
        CallTypes.acompletion,
        CallTypes.atext_completion,
        CallTypes.aembedding,
        CallTypes.arerank,
        CallTypes.atranscription,
    ],
)
@pytest.mark.asyncio
async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
    """Check that each async Azure call type goes through initialize_azure_sdk_client.

    The router method named by ``call_type.value`` is invoked against a single
    Azure deployment while ``initialize_azure_sdk_client`` is patched; the test
    then asserts the patched hook was called exactly once with ``api_key`` and
    ``api_base`` keyword arguments.
    """
    from litellm.router import Router

    # Create a router with an Azure model
    azure_model_name = "azure/chatgpt-v-2"
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": azure_model_name,
                    "api_key": "test-api-key",
                    "api_version": os.getenv("AZURE_API_VERSION", "2023-05-15"),
                    "api_base": os.getenv(
                        "AZURE_API_BASE", "https://test.openai.azure.com"
                    ),
                },
            }
        ],
    )

    # Request payload per call type, keyed by the router method name.
    test_inputs = {
        "acompletion": {
            "messages": [{"role": "user", "content": "Hello, how are you?"}]
        },
        "atext_completion": {"prompt": "Hello, how are you?"},
        "aimage_generation": {"prompt": "Hello, how are you?"},
        "aembedding": {"input": "Hello, how are you?"},
        "arerank": {"input": "Hello, how are you?"},
        "atranscription": {"file": "path/to/file"},
    }

    # Unknown call types fall back to no extra keyword arguments.
    input_kwarg = test_inputs.get(call_type.value, {})

    # Patch the client-parameter builder so its invocation can be observed.
    with patch(
        "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
    ) as mock_init_azure:
        # The request itself may fail; any exception is printed and ignored,
        # because only the call to the patched hook is being verified.
        try:
            await getattr(router, call_type.value)(
                model="gpt-3.5-turbo",
                **input_kwarg,
                num_retries=0,
            )
        except Exception as e:
            print(e)

        # Verify initialize_azure_sdk_client was called
        mock_init_azure.assert_called_once()

        # Verify the credentials/endpoint were passed as keyword arguments.
        for recorded_call in mock_init_azure.call_args_list:
            assert (
                "api_key" in recorded_call.kwargs
            ), "api_key not found in parameters"
            assert (
                "api_base" in recorded_call.kwargs
            ), "api_base not found in parameters"