mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
get_openai_client_cache_key
parent 40418c7bd8
commit 65083ca8da
4 changed files with 199 additions and 10 deletions
@@ -263,8 +263,9 @@ class BaseAzureLLM(BaseOpenAILLM):
             client_initialization_params=client_initialization_params,
             client_type="azure",
         )
-        if cached_client and isinstance(
-            cached_client, (AzureOpenAI, AsyncAzureOpenAI)
-        ):
-            return cached_client
+        if cached_client:
+            if isinstance(cached_client, AzureOpenAI) or isinstance(
+                cached_client, AsyncAzureOpenAI
+            ):
+                return cached_client
 
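A side note on the rewritten check: `isinstance` accepts a tuple of types, so the removed one-liner and the new chained form accept exactly the same objects. A quick self-contained illustration (constructing a client locally makes no network calls):

from openai import AsyncAzureOpenAI, AzureOpenAI

client = AzureOpenAI(
    api_key="test-api-key",
    api_version="2023-05-15",
    azure_endpoint="https://test.openai.azure.com",
)

# The tuple form and the chained form are equivalent:
assert isinstance(client, (AzureOpenAI, AsyncAzureOpenAI))
assert isinstance(client, AzureOpenAI) or isinstance(client, AsyncAzureOpenAI)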
@@ -151,13 +151,23 @@ class BaseOpenAILLM:
             f"is_async={client_initialization_params.get('is_async')}",
         ]
 
-        for param in BaseOpenAILLM.get_openai_client_initialization_param_fields(
-            client_type=client_type
-        ):
+        LITELLM_CLIENT_SPECIFIC_PARAMS = [
+            "timeout",
+            "max_retries",
+            "organization",
+            "api_base",
+        ]
+        openai_client_fields = (
+            BaseOpenAILLM.get_openai_client_initialization_param_fields(
+                client_type=client_type
+            )
+            + LITELLM_CLIENT_SPECIFIC_PARAMS
+        )
+
+        for param in openai_client_fields:
             key_parts.append(f"{param}={client_initialization_params.get(param)}")
 
         _cache_key = ",".join(key_parts)
 
         return _cache_key
 
     @staticmethod
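To make the key construction above concrete: every relevant initialization parameter is rendered as `name=value` and the parts are comma-joined, so two calls with identical parameters map to the same key. A standalone sketch using an assumed parameter dict (litellm builds the real one from the caller's arguments):

client_initialization_params = {
    "is_async": False,
    "api_key": "test-api-key",
    "api_base": "https://test.openai.azure.com",
    "timeout": 600.0,
    "max_retries": 2,
    "organization": None,
}

# Mirrors the hunk: client-type-specific fields plus the litellm-specific
# params each contribute one "name=value" part to the key.
key_parts = [f"is_async={client_initialization_params.get('is_async')}"]
for param in ["timeout", "max_retries", "organization", "api_base"]:
    key_parts.append(f"{param}={client_initialization_params.get(param)}")

print(",".join(key_parts))
# is_async=False,timeout=600.0,max_retries=2,organization=None,api_base=https://test.openai.azure.com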
@@ -346,7 +346,7 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
         max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
-    ):
+    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
         client_initialization_params: Dict = locals()
         if client is None:
             if not isinstance(max_retries, int):
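Worth noting in this hunk: `client_initialization_params: Dict = locals()` on the first line of the body snapshots every argument the method received into a dict, which is what the cache-key builder consumes. A small demonstration of the idiom:

def configure(api_key=None, timeout=600.0, max_retries=2):
    # locals() at the top of a function contains exactly the arguments;
    # copy it before binding any other local names.
    params = dict(locals())
    return params

print(configure(api_key="test-api-key"))
# {'api_key': 'test-api-key', 'timeout': 600.0, 'max_retries': 2}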
@@ -360,7 +360,11 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
             client_initialization_params=client_initialization_params,
             client_type="openai",
         )
 
         if cached_client:
+            if isinstance(cached_client, OpenAI) or isinstance(
+                cached_client, AsyncOpenAI
+            ):
+                return cached_client
         if is_async:
             _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
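Together with the Azure hunk, these changes implement a get-or-create flow: compute a cache key from the initialization parameters, return the cached client if one of the expected type exists, otherwise construct a sync or async client and cache it. A minimal sketch under assumed names (`_client_cache` and `make_cache_key` are illustrative stand-ins, not litellm's actual helpers):

from typing import Dict, Union

from openai import AsyncOpenAI, OpenAI

_client_cache: Dict[str, Union[OpenAI, AsyncOpenAI]] = {}

def make_cache_key(params: dict) -> str:
    # Stand-in for BaseOpenAILLM.get_openai_client_cache_key.
    return ",".join(f"{k}={v}" for k, v in sorted(params.items()))

def get_openai_client(params: dict, is_async: bool) -> Union[OpenAI, AsyncOpenAI]:
    key = make_cache_key({**params, "is_async": is_async})
    cached = _client_cache.get(key)
    # Only reuse the cached object if it really is an OpenAI client.
    if cached is not None and isinstance(cached, (OpenAI, AsyncOpenAI)):
        return cached
    client = AsyncOpenAI(**params) if is_async else OpenAI(**params)
    _client_cache[key] = client
    return client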
@@ -461,3 +461,177 @@ async def test_ensure_initialize_azure_sdk_client_always_used_azure_text(call_type):
     for call in azure_calls:
         assert "api_key" in call.kwargs, "api_key not found in parameters"
         assert "api_base" in call.kwargs, "api_base not found in parameters"
+
+
+# Test parameters for different API functions with Azure models
+AZURE_API_FUNCTION_PARAMS = [
+    # (function_name, is_async, args)
+    (
+        "completion",
+        False,
+        {
+            "model": "azure/gpt-4",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10,
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "completion",
+        True,
+        {
+            "model": "azure/gpt-4",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10,
+            "stream": True,
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "embedding",
+        False,
+        {
+            "model": "azure/text-embedding-ada-002",
+            "input": "Hello world",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "embedding",
+        True,
+        {
+            "model": "azure/text-embedding-ada-002",
+            "input": "Hello world",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "speech",
+        False,
+        {
+            "model": "azure/tts-1",
+            "input": "Hello, this is a test of text to speech",
+            "voice": "alloy",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "speech",
+        True,
+        {
+            "model": "azure/tts-1",
+            "input": "Hello, this is a test of text to speech",
+            "voice": "alloy",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "transcription",
+        False,
+        {
+            "model": "azure/whisper-1",
+            "file": MagicMock(),
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "transcription",
+        True,
+        {
+            "model": "azure/whisper-1",
+            "file": MagicMock(),
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+]
+
+
+@pytest.mark.parametrize("function_name,is_async,args", AZURE_API_FUNCTION_PARAMS)
+@pytest.mark.asyncio
+async def test_azure_client_reuse(function_name, is_async, args):
+    """
+    Test that multiple Azure API calls reuse the same Azure OpenAI client
+    """
+    litellm.set_verbose = True
+
+    # Determine which client class to mock based on whether the test is async
+    client_path = (
+        "litellm.llms.azure.common_utils.AsyncAzureOpenAI"
+        if is_async
+        else "litellm.llms.azure.common_utils.AzureOpenAI"
+    )
+
+    # Create a proper mock class that can pass isinstance checks
+    mock_client = MagicMock()
+
+    # Create the appropriate patches
+    with patch(client_path) as mock_client_class, patch.object(
+        BaseAzureLLM, "set_cached_openai_client"
+    ) as mock_set_cache, patch.object(
+        BaseAzureLLM, "get_cached_openai_client"
+    ) as mock_get_cache, patch.object(
+        BaseAzureLLM, "initialize_azure_sdk_client"
+    ) as mock_init_azure:
+        # Configure the mock client class to return our mock instance
+        mock_client_class.return_value = mock_client
+
+        # Setup the mock to return None first time (cache miss) then a client for subsequent calls
+        mock_get_cache.side_effect = [None] + [
+            mock_client
+        ] * 9  # First call returns None, rest return the mock client
+
+        # Mock the initialize_azure_sdk_client to return a dict with the necessary params
+        mock_init_azure.return_value = {
+            "api_key": args.get("api_key"),
+            "azure_endpoint": args.get("api_base"),
+            "api_version": args.get("api_version"),
+            "azure_ad_token": None,
+            "azure_ad_token_provider": None,
+        }
+
+        # Make 10 API calls
+        for _ in range(10):
+            try:
+                # Call the appropriate function based on parameters
+                if is_async:
+                    # Add 'a' prefix for async functions
+                    func = getattr(litellm, f"a{function_name}")
+                    await func(**args)
+                else:
+                    func = getattr(litellm, function_name)
+                    func(**args)
+            except Exception:
+                # We expect exceptions since we're mocking the client
+                pass
+
+        # Verify client was created only once
+        assert (
+            mock_client_class.call_count == 1
+        ), f"{'Async' if is_async else ''}AzureOpenAI client should be created only once"
+
+        # Verify initialize_azure_sdk_client was called once
+        assert (
+            mock_init_azure.call_count == 1
+        ), "initialize_azure_sdk_client should be called once"
+
+        # Verify the client was cached
+        assert mock_set_cache.call_count == 1, "Client should be cached once"
+
+        # Verify we tried to get from cache 10 times (once per request)
+        assert mock_get_cache.call_count == 10, "Should check cache for each request"
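The cache-miss setup in the test leans on `unittest.mock`: when `side_effect` is a list, successive calls return successive elements. A compact illustration of the pattern used for `mock_get_cache`:

from unittest.mock import MagicMock

get_cache = MagicMock(side_effect=[None] + ["cached-client"] * 2)

print(get_cache())  # None -> first lookup is a cache miss
print(get_cache())  # 'cached-client' -> later lookups hit the cache
print(get_cache())  # 'cached-client'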