diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index 4d9c35a5fb..71092c8b99 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -263,10 +263,11 @@ class BaseAzureLLM(BaseOpenAILLM): client_initialization_params=client_initialization_params, client_type="azure", ) - if cached_client and isinstance( - cached_client, (AzureOpenAI, AsyncAzureOpenAI) - ): - return cached_client + if cached_client: + if isinstance(cached_client, AzureOpenAI) or isinstance( + cached_client, AsyncAzureOpenAI + ): + return cached_client azure_client_params = self.initialize_azure_sdk_client( litellm_params=litellm_params or {}, diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py index f9ba366cb5..55da16d6cd 100644 --- a/litellm/llms/openai/common_utils.py +++ b/litellm/llms/openai/common_utils.py @@ -151,13 +151,23 @@ class BaseOpenAILLM: f"is_async={client_initialization_params.get('is_async')}", ] - for param in BaseOpenAILLM.get_openai_client_initialization_param_fields( - client_type=client_type - ): + LITELLM_CLIENT_SPECIFIC_PARAMS = [ + "timeout", + "max_retries", + "organization", + "api_base", + ] + openai_client_fields = ( + BaseOpenAILLM.get_openai_client_initialization_param_fields( + client_type=client_type + ) + + LITELLM_CLIENT_SPECIFIC_PARAMS + ) + + for param in openai_client_fields: key_parts.append(f"{param}={client_initialization_params.get(param)}") _cache_key = ",".join(key_parts) - return _cache_key @staticmethod diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index 475c83be34..98ef95239e 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -346,7 +346,7 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): max_retries: Optional[int] = DEFAULT_MAX_RETRIES, organization: Optional[str] = None, client: Optional[Union[OpenAI, AsyncOpenAI]] = None, - ): + ) -> Optional[Union[OpenAI, AsyncOpenAI]]: client_initialization_params: Dict = locals() if client is None: if not isinstance(max_retries, int): @@ -360,8 +360,12 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM): client_initialization_params=client_initialization_params, client_type="openai", ) + if cached_client: - return cached_client + if isinstance(cached_client, OpenAI) or isinstance( + cached_client, AsyncOpenAI + ): + return cached_client if is_async: _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI( api_key=api_key, diff --git a/tests/litellm/llms/azure/test_azure_common_utils.py b/tests/litellm/llms/azure/test_azure_common_utils.py index 4d24009685..a9e63f84f2 100644 --- a/tests/litellm/llms/azure/test_azure_common_utils.py +++ b/tests/litellm/llms/azure/test_azure_common_utils.py @@ -461,3 +461,177 @@ async def test_ensure_initialize_azure_sdk_client_always_used_azure_text(call_ty for call in azure_calls: assert "api_key" in call.kwargs, "api_key not found in parameters" assert "api_base" in call.kwargs, "api_base not found in parameters" + + +# Test parameters for different API functions with Azure models +AZURE_API_FUNCTION_PARAMS = [ + # (function_name, is_async, args) + ( + "completion", + False, + { + "model": "azure/gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), + ( + "completion", + True, + { + "model": "azure/gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + "stream": True, + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), + ( + "embedding", + False, + { + "model": "azure/text-embedding-ada-002", + "input": "Hello world", + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), + ( + "embedding", + True, + { + "model": "azure/text-embedding-ada-002", + "input": "Hello world", + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), + ( + "speech", + False, + { + "model": "azure/tts-1", + "input": "Hello, this is a test of text to speech", + "voice": "alloy", + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), + ( + "speech", + True, + { + "model": "azure/tts-1", + "input": "Hello, this is a test of text to speech", + "voice": "alloy", + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), + ( + "transcription", + False, + { + "model": "azure/whisper-1", + "file": MagicMock(), + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), + ( + "transcription", + True, + { + "model": "azure/whisper-1", + "file": MagicMock(), + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com", + "api_version": "2023-05-15", + }, + ), +] + + +@pytest.mark.parametrize("function_name,is_async,args", AZURE_API_FUNCTION_PARAMS) +@pytest.mark.asyncio +async def test_azure_client_reuse(function_name, is_async, args): + """ + Test that multiple Azure API calls reuse the same Azure OpenAI client + """ + litellm.set_verbose = True + + # Determine which client class to mock based on whether the test is async + client_path = ( + "litellm.llms.azure.common_utils.AsyncAzureOpenAI" + if is_async + else "litellm.llms.azure.common_utils.AzureOpenAI" + ) + + # Create a proper mock class that can pass isinstance checks + mock_client = MagicMock() + + # Create the appropriate patches + with patch(client_path) as mock_client_class, patch.object( + BaseAzureLLM, "set_cached_openai_client" + ) as mock_set_cache, patch.object( + BaseAzureLLM, "get_cached_openai_client" + ) as mock_get_cache, patch.object( + BaseAzureLLM, "initialize_azure_sdk_client" + ) as mock_init_azure: + # Configure the mock client class to return our mock instance + mock_client_class.return_value = mock_client + + # Setup the mock to return None first time (cache miss) then a client for subsequent calls + mock_get_cache.side_effect = [None] + [ + mock_client + ] * 9 # First call returns None, rest return the mock client + + # Mock the initialize_azure_sdk_client to return a dict with the necessary params + mock_init_azure.return_value = { + "api_key": args.get("api_key"), + "azure_endpoint": args.get("api_base"), + "api_version": args.get("api_version"), + "azure_ad_token": None, + "azure_ad_token_provider": None, + } + + # Make 10 API calls + for _ in range(10): + try: + # Call the appropriate function based on parameters + if is_async: + # Add 'a' prefix for async functions + func = getattr(litellm, f"a{function_name}") + await func(**args) + else: + func = getattr(litellm, function_name) + func(**args) + except Exception: + # We expect exceptions since we're mocking the client + pass + + # Verify client was created only once + assert ( + mock_client_class.call_count == 1 + ), f"{'Async' if is_async else ''}AzureOpenAI client should be created only once" + + # Verify initialize_azure_sdk_client was called once + assert ( + mock_init_azure.call_count == 1 + ), "initialize_azure_sdk_client should be called once" + + # Verify the client was cached + assert mock_set_cache.call_count == 1, "Client should be cached once" + + # Verify we tried to get from cache 10 times (once per request) + assert mock_get_cache.call_count == 10, "Should check cache for each request"