Repository mirror of https://github.com/BerriAI/litellm.git

Commit 65083ca8da (parent 40418c7bd8): get_openai_client_cache_key
4 changed files with 199 additions and 10 deletions
@@ -263,10 +263,11 @@ class BaseAzureLLM(BaseOpenAILLM):
             client_initialization_params=client_initialization_params,
             client_type="azure",
         )
-        if cached_client and isinstance(
-            cached_client, (AzureOpenAI, AsyncAzureOpenAI)
-        ):
-            return cached_client
+        if cached_client:
+            if isinstance(cached_client, AzureOpenAI) or isinstance(
+                cached_client, AsyncAzureOpenAI
+            ):
+                return cached_client

         azure_client_params = self.initialize_azure_sdk_client(
             litellm_params=litellm_params or {},
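The isinstance guard matters because the client cache is shared, so a key collision could hand back an object that is not an Azure client at all. A minimal sketch of the get-or-create pattern this hunk guards, using a plain dict in place of litellm's internal cache (`get_or_create_azure_client` and `_client_cache` are illustrative names, not litellm APIs):

```python
from typing import Dict, Union

from openai import AsyncAzureOpenAI, AzureOpenAI

# Illustrative stand-in for litellm's internal client cache
_client_cache: Dict[str, object] = {}


def get_or_create_azure_client(
    cache_key: str, **client_kwargs
) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
    cached = _client_cache.get(cache_key)
    # Reuse the entry only if it really is an Azure client; a shared
    # cache may hold other object types under a colliding key.
    if cached is not None and isinstance(cached, (AzureOpenAI, AsyncAzureOpenAI)):
        return cached
    # Constructor expects api_key, api_version, azure_endpoint, etc.
    client = AzureOpenAI(**client_kwargs)
    _client_cache[cache_key] = client
    return client
```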
@@ -151,13 +151,23 @@ class BaseOpenAILLM:
             f"is_async={client_initialization_params.get('is_async')}",
         ]

-        for param in BaseOpenAILLM.get_openai_client_initialization_param_fields(
-            client_type=client_type
-        ):
+        LITELLM_CLIENT_SPECIFIC_PARAMS = [
+            "timeout",
+            "max_retries",
+            "organization",
+            "api_base",
+        ]
+        openai_client_fields = (
+            BaseOpenAILLM.get_openai_client_initialization_param_fields(
+                client_type=client_type
+            )
+            + LITELLM_CLIENT_SPECIFIC_PARAMS
+        )
+        for param in openai_client_fields:
             key_parts.append(f"{param}={client_initialization_params.get(param)}")

         _cache_key = ",".join(key_parts)

         return _cache_key

     @staticmethod
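This hunk extends the cache key with the litellm-specific client params (timeout, max_retries, organization, api_base), so two configurations that differ only in, say, api_base no longer collide on one cached client. A rough sketch of the effect (`build_cache_key` is a simplified stand-in, not the method's exact code):

```python
from typing import Dict, List

LITELLM_CLIENT_SPECIFIC_PARAMS: List[str] = [
    "timeout",
    "max_retries",
    "organization",
    "api_base",
]


def build_cache_key(params: Dict, openai_fields: List[str]) -> str:
    key_parts = [f"is_async={params.get('is_async')}"]
    # Client-specific params are appended to the provider's own fields,
    # so they participate in cache-key equality.
    for field in openai_fields + LITELLM_CLIENT_SPECIFIC_PARAMS:
        key_parts.append(f"{field}={params.get(field)}")
    return ",".join(key_parts)


# Same credentials, different endpoints -> distinct keys, distinct clients
key_a = build_cache_key(
    {"is_async": False, "api_key": "sk-x", "api_base": "https://a.example"}, ["api_key"]
)
key_b = build_cache_key(
    {"is_async": False, "api_key": "sk-x", "api_base": "https://b.example"}, ["api_key"]
)
assert key_a != key_b
```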
@@ -346,7 +346,7 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
         max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
-    ):
+    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
         client_initialization_params: Dict = locals()
         if client is None:
             if not isinstance(max_retries, int):
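The new return annotation makes the possible result types explicit; one consequence, sketched below, is that callers and type checkers must handle the None case before using the client (`pick_client` is a hypothetical illustration, not litellm code):

```python
from typing import Optional, Union

from openai import AsyncOpenAI, OpenAI


def pick_client(cached: object) -> Optional[Union[OpenAI, AsyncOpenAI]]:
    # Mirrors the guard added in the next hunk: only a genuine client
    # instance is returned; anything else yields None.
    if isinstance(cached, (OpenAI, AsyncOpenAI)):
        return cached
    return None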
@@ -360,8 +360,12 @@ class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):
                 client_initialization_params=client_initialization_params,
                 client_type="openai",
             )
+
             if cached_client:
-                return cached_client
+                if isinstance(cached_client, OpenAI) or isinstance(
+                    cached_client, AsyncOpenAI
+                ):
+                    return cached_client
             if is_async:
                 _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                     api_key=api_key,
@@ -461,3 +461,177 @@ async def test_ensure_initialize_azure_sdk_client_always_used_azure_text(call_ty
     for call in azure_calls:
         assert "api_key" in call.kwargs, "api_key not found in parameters"
         assert "api_base" in call.kwargs, "api_base not found in parameters"
+
+
+# Test parameters for different API functions with Azure models
+AZURE_API_FUNCTION_PARAMS = [
+    # (function_name, is_async, args)
+    (
+        "completion",
+        False,
+        {
+            "model": "azure/gpt-4",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10,
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "completion",
+        True,
+        {
+            "model": "azure/gpt-4",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 10,
+            "stream": True,
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "embedding",
+        False,
+        {
+            "model": "azure/text-embedding-ada-002",
+            "input": "Hello world",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "embedding",
+        True,
+        {
+            "model": "azure/text-embedding-ada-002",
+            "input": "Hello world",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "speech",
+        False,
+        {
+            "model": "azure/tts-1",
+            "input": "Hello, this is a test of text to speech",
+            "voice": "alloy",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "speech",
+        True,
+        {
+            "model": "azure/tts-1",
+            "input": "Hello, this is a test of text to speech",
+            "voice": "alloy",
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "transcription",
+        False,
+        {
+            "model": "azure/whisper-1",
+            "file": MagicMock(),
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+    (
+        "transcription",
+        True,
+        {
+            "model": "azure/whisper-1",
+            "file": MagicMock(),
+            "api_key": "test-api-key",
+            "api_base": "https://test.openai.azure.com",
+            "api_version": "2023-05-15",
+        },
+    ),
+]
+
+
+@pytest.mark.parametrize("function_name,is_async,args", AZURE_API_FUNCTION_PARAMS)
+@pytest.mark.asyncio
+async def test_azure_client_reuse(function_name, is_async, args):
+    """
+    Test that multiple Azure API calls reuse the same Azure OpenAI client
+    """
+    litellm.set_verbose = True
+
+    # Determine which client class to mock based on whether the test is async
+    client_path = (
+        "litellm.llms.azure.common_utils.AsyncAzureOpenAI"
+        if is_async
+        else "litellm.llms.azure.common_utils.AzureOpenAI"
+    )
+
+    # Create a proper mock class that can pass isinstance checks
+    mock_client = MagicMock()
+
+    # Create the appropriate patches
+    with patch(client_path) as mock_client_class, patch.object(
+        BaseAzureLLM, "set_cached_openai_client"
+    ) as mock_set_cache, patch.object(
+        BaseAzureLLM, "get_cached_openai_client"
+    ) as mock_get_cache, patch.object(
+        BaseAzureLLM, "initialize_azure_sdk_client"
+    ) as mock_init_azure:
+        # Configure the mock client class to return our mock instance
+        mock_client_class.return_value = mock_client
+
+        # Setup the mock to return None first time (cache miss) then a client for subsequent calls
+        mock_get_cache.side_effect = [None] + [
+            mock_client
+        ] * 9  # First call returns None, rest return the mock client
+
+        # Mock the initialize_azure_sdk_client to return a dict with the necessary params
+        mock_init_azure.return_value = {
+            "api_key": args.get("api_key"),
+            "azure_endpoint": args.get("api_base"),
+            "api_version": args.get("api_version"),
+            "azure_ad_token": None,
+            "azure_ad_token_provider": None,
+        }
+
+        # Make 10 API calls
+        for _ in range(10):
+            try:
+                # Call the appropriate function based on parameters
+                if is_async:
+                    # Add 'a' prefix for async functions
+                    func = getattr(litellm, f"a{function_name}")
+                    await func(**args)
+                else:
+                    func = getattr(litellm, function_name)
+                    func(**args)
+            except Exception:
+                # We expect exceptions since we're mocking the client
+                pass
+
+        # Verify client was created only once
+        assert (
+            mock_client_class.call_count == 1
+        ), f"{'Async' if is_async else ''}AzureOpenAI client should be created only once"
+
+        # Verify initialize_azure_sdk_client was called once
+        assert (
+            mock_init_azure.call_count == 1
+        ), "initialize_azure_sdk_client should be called once"
+
+        # Verify the client was cached
+        assert mock_set_cache.call_count == 1, "Client should be cached once"
+
+        # Verify we tried to get from cache 10 times (once per request)
+        assert mock_get_cache.call_count == 10, "Should check cache for each request"
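One subtlety in the test above: a bare MagicMock() does not normally pass an isinstance check against the real AzureOpenAI class; it works here only because patch() replaces the class itself, so the isinstance comparison runs against the mock class. If the class were not patched, spec= would be the usual way to satisfy isinstance, as in this small sketch:

```python
from unittest.mock import MagicMock

from openai import AzureOpenAI

# spec= sets the mock's __class__, so isinstance checks against the
# real class succeed without patching the class itself.
mock_client = MagicMock(spec=AzureOpenAI)
assert isinstance(mock_client, AzureOpenAI)
```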