get_openai_client_cache_key

Ishaan Jaff 2025-03-18 18:35:50 -07:00
parent bb8400a350
commit 55e669d7d8
4 changed files with 199 additions and 10 deletions

@@ -263,10 +263,11 @@ class BaseAzureLLM(BaseOpenAILLM):
             client_initialization_params=client_initialization_params,
             client_type="azure",
         )
-        if cached_client and isinstance(
-            cached_client, (AzureOpenAI, AsyncAzureOpenAI)
-        ):
-            return cached_client
+        if cached_client:
+            if isinstance(cached_client, AzureOpenAI) or isinstance(
+                cached_client, AsyncAzureOpenAI
+            ):
+                return cached_client
 
         azure_client_params = self.initialize_azure_sdk_client(
             litellm_params=litellm_params or {},
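The new guard only reuses a cache hit when the stored object really is an Azure client. A minimal sketch of the same lookup-then-type-check pattern, assuming a hypothetical module-level _client_cache dict and helper name (not litellm's actual cache internals):

from typing import Optional, Union

from openai import AsyncAzureOpenAI, AzureOpenAI

_client_cache: dict = {}  # hypothetical in-memory client cache


def get_cached_azure_client(
    cache_key: str,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
    cached_client = _client_cache.get(cache_key)
    if cached_client:
        # Only reuse the cached object if it is an Azure client, sync or
        # async; anything else is treated as a cache miss.
        if isinstance(cached_client, (AzureOpenAI, AsyncAzureOpenAI)):
            return cached_client
    return None

The isinstance check means a mismatched or corrupted cache entry falls through to fresh client construction instead of being returned blindly.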

@@ -151,13 +151,23 @@ class BaseOpenAILLM:
             f"is_async={client_initialization_params.get('is_async')}",
         ]
 
-        for param in BaseOpenAILLM.get_openai_client_initialization_param_fields(
-            client_type=client_type
-        ):
+        LITELLM_CLIENT_SPECIFIC_PARAMS = [
+            "timeout",
+            "max_retries",
+            "organization",
+            "api_base",
+        ]
+        openai_client_fields = (
+            BaseOpenAILLM.get_openai_client_initialization_param_fields(
+                client_type=client_type
+            )
+            + LITELLM_CLIENT_SPECIFIC_PARAMS
+        )
+
+        for param in openai_client_fields:
             key_parts.append(f"{param}={client_initialization_params.get(param)}")
 
         _cache_key = ",".join(key_parts)
         return _cache_key
 
     @staticmethod
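This hunk folds timeout, max_retries, organization, and api_base into the cache key, so clients that differ in any of those settings no longer collide. The key itself is a comma-joined list of param=value parts. A rough, runnable illustration of the resulting key shape, using made-up values (the real key also includes the is_async flag from the surrounding code and whatever fields get_openai_client_initialization_param_fields returns at runtime):

client_initialization_params = {
    "is_async": False,
    "timeout": 600.0,
    "max_retries": 2,
    "organization": None,
    "api_base": "https://api.openai.com/v1",
}

key_parts = [f"is_async={client_initialization_params.get('is_async')}"]
for param in ["timeout", "max_retries", "organization", "api_base"]:
    key_parts.append(f"{param}={client_initialization_params.get(param)}")

print(",".join(key_parts))
# is_async=False,timeout=600.0,max_retries=2,organization=None,api_base=https://api.openai.com/v1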

@@ -346,7 +346,7 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
         max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
-    ):
+    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
         client_initialization_params: Dict = locals()
         if client is None:
             if not isinstance(max_retries, int):
@@ -360,8 +360,12 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
             client_initialization_params=client_initialization_params,
             client_type="openai",
         )
         if cached_client:
-            return cached_client
+            if isinstance(cached_client, OpenAI) or isinstance(
+                cached_client, AsyncOpenAI
+            ):
+                return cached_client
+
         if is_async:
             _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                 api_key=api_key,
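Taken together with the key construction above, the flow is: derive a key from the init params, return a type-checked cache hit, otherwise build a client and cache it under that key. A condensed sketch of that flow; _client_cache and the function name are stand-ins, not litellm's API:

from typing import Optional, Union

from openai import AsyncOpenAI, OpenAI

_client_cache: dict = {}  # hypothetical in-memory cache


def get_or_create_openai_client(
    is_async: bool,
    api_key: Optional[str] = None,  # callers must supply a key (or set OPENAI_API_KEY)
    api_base: Optional[str] = None,
    timeout: float = 600.0,
    max_retries: int = 2,
) -> Union[OpenAI, AsyncOpenAI]:
    cache_key = (
        f"is_async={is_async},api_key={api_key},"
        f"timeout={timeout},max_retries={max_retries},api_base={api_base}"
    )
    cached = _client_cache.get(cache_key)
    if isinstance(cached, (OpenAI, AsyncOpenAI)):
        return cached  # reuse avoids rebuilding HTTP connection pools per call
    client_cls = AsyncOpenAI if is_async else OpenAI
    new_client = client_cls(
        api_key=api_key, base_url=api_base, timeout=timeout, max_retries=max_retries
    )
    _client_cache[cache_key] = new_client
    return new_client

One design note: the sketch embeds the raw api_key in the cache key for brevity; a production version would hash it first so secrets never appear in key strings or logs.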

@@ -461,3 +461,177 @@ async def test_ensure_initialize_azure_sdk_client_always_used_azure_text(call_type):
     for call in azure_calls:
         assert "api_key" in call.kwargs, "api_key not found in parameters"
         assert "api_base" in call.kwargs, "api_base not found in parameters"
+
+
+# Test parameters for different API functions with Azure models
+AZURE_API_FUNCTION_PARAMS = [
+    # (function_name, is_async, args)
+    (
+        "completion",
+        False,
+        {
+            "model": "azure/gpt-4",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10,
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "completion",
+        True,
+        {
+            "model": "azure/gpt-4",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10,
+            "stream": True,
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "embedding",
+        False,
+        {
+            "model": "azure/text-embedding-ada-002",
+            "input": "Hello world",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "embedding",
+        True,
+        {
+            "model": "azure/text-embedding-ada-002",
+            "input": "Hello world",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "speech",
+        False,
+        {
+            "model": "azure/tts-1",
+            "input": "Hello, this is a test of text to speech",
+            "voice": "alloy",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "speech",
+        True,
+        {
+            "model": "azure/tts-1",
+            "input": "Hello, this is a test of text to speech",
+            "voice": "alloy",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "transcription",
+        False,
+        {
+            "model": "azure/whisper-1",
+            "file": MagicMock(),
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "transcription",
+        True,
+        {
+            "model": "azure/whisper-1",
+            "file": MagicMock(),
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+]
+
+
+@pytest.mark.parametrize("function_name,is_async,args", AZURE_API_FUNCTION_PARAMS)
+@pytest.mark.asyncio
+async def test_azure_client_reuse(function_name, is_async, args):
+    """
+    Test that multiple Azure API calls reuse the same Azure OpenAI client
+    """
+    litellm.set_verbose = True
+
+    # Determine which client class to mock based on whether the test is async
+    client_path = (
+        "litellm.llms.azure.common_utils.AsyncAzureOpenAI"
+        if is_async
+        else "litellm.llms.azure.common_utils.AzureOpenAI"
+    )
+
+    # Create a proper mock class that can pass isinstance checks
+    mock_client = MagicMock()
+
+    # Create the appropriate patches
+    with patch(client_path) as mock_client_class, patch.object(
+        BaseAzureLLM, "set_cached_openai_client"
+    ) as mock_set_cache, patch.object(
+        BaseAzureLLM, "get_cached_openai_client"
+    ) as mock_get_cache, patch.object(
+        BaseAzureLLM, "initialize_azure_sdk_client"
+    ) as mock_init_azure:
+        # Configure the mock client class to return our mock instance
+        mock_client_class.return_value = mock_client
+
+        # Setup the mock to return None first time (cache miss) then a client for subsequent calls
+        mock_get_cache.side_effect = [None] + [
+            mock_client
+        ] * 9  # First call returns None, rest return the mock client
+
+        # Mock the initialize_azure_sdk_client to return a dict with the necessary params
+        mock_init_azure.return_value = {
+            "api_key": args.get("api_key"),
+            "azure_endpoint": args.get("api_base"),
+            "api_version": args.get("api_version"),
+            "azure_ad_token": None,
+            "azure_ad_token_provider": None,
+        }
+
+        # Make 10 API calls
+        for _ in range(10):
+            try:
+                # Call the appropriate function based on parameters
+                if is_async:
+                    # Add 'a' prefix for async functions
+                    func = getattr(litellm, f"a{function_name}")
+                    await func(**args)
+                else:
+                    func = getattr(litellm, function_name)
+                    func(**args)
+            except Exception:
+                # We expect exceptions since we're mocking the client
+                pass
+
+        # Verify client was created only once
+        assert (
+            mock_client_class.call_count == 1
+        ), f"{'Async' if is_async else ''}AzureOpenAI client should be created only once"
+
+        # Verify initialize_azure_sdk_client was called once
+        assert (
+            mock_init_azure.call_count == 1
+        ), "initialize_azure_sdk_client should be called once"
+
+        # Verify the client was cached
+        assert mock_set_cache.call_count == 1, "Client should be cached once"
+
+        # Verify we tried to get from cache 10 times (once per request)
+        assert mock_get_cache.call_count == 10, "Should check cache for each request"
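The call-count assertions hinge on the mocked cache's side_effect: the first lookup returns None (a miss, so exactly one client is constructed and cached) and the next nine return the mock (hits). A self-contained illustration of that miss-then-hit pattern, independent of litellm:

from unittest.mock import MagicMock

mock_client = MagicMock()
get_cache = MagicMock(side_effect=[None] + [mock_client] * 9)

created = 0
for _ in range(10):
    client = get_cache()
    if client is None:  # cache miss: "create" and reuse the client afterwards
        created += 1
        client = mock_client

assert get_cache.call_count == 10  # one lookup per request
assert created == 1  # the client is only built once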