mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
refactor(azure.py): working azure client init on audio speech endpoint
This commit is contained in:
parent
f7d9cce536
commit
152bc67d22
5 changed files with 63 additions and 46 deletions
|
@ -219,8 +219,12 @@ def test_select_azure_base_url_called(setup_mocks):
|
|||
CallTypes.acompletion,
|
||||
CallTypes.atext_completion,
|
||||
CallTypes.aembedding,
|
||||
CallTypes.arerank,
|
||||
CallTypes.atranscription,
|
||||
# CallTypes.arerank,
|
||||
# CallTypes.atranscription,
|
||||
CallTypes.aspeech,
|
||||
CallTypes.aimage_generation,
|
||||
# BATCHES ENDPOINTS
|
||||
# ASSISTANT ENDPOINTS
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
@ -255,15 +259,20 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
|
|||
"aembedding": {"input": "Hello, how are you?"},
|
||||
"arerank": {"input": "Hello, how are you?"},
|
||||
"atranscription": {"file": "path/to/file"},
|
||||
"aspeech": {"input": "Hello, how are you?", "voice": "female"},
|
||||
}
|
||||
|
||||
# Get appropriate input for this call type
|
||||
input_kwarg = test_inputs.get(call_type.value, {})
|
||||
|
||||
patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
|
||||
if call_type == CallTypes.atranscription:
|
||||
patch_target = (
|
||||
"litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client"
|
||||
)
|
||||
|
||||
# Mock the initialize_azure_sdk_client function
|
||||
with patch(
|
||||
"litellm.main.azure_chat_completions.initialize_azure_sdk_client"
|
||||
) as mock_init_azure:
|
||||
with patch(patch_target) as mock_init_azure:
|
||||
# Also mock async_function_with_fallbacks to prevent actual API calls
|
||||
# Call the appropriate router method
|
||||
try:
|
||||
|
@ -271,6 +280,7 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
|
|||
model="gpt-3.5-turbo",
|
||||
**input_kwarg,
|
||||
num_retries=0,
|
||||
azure_ad_token="oidc/test-token",
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
@ -282,6 +292,16 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
|
|||
calls = mock_init_azure.call_args_list
|
||||
azure_calls = [call for call in calls]
|
||||
|
||||
litellm_params = azure_calls[0].kwargs["litellm_params"]
|
||||
print("litellm_params", litellm_params)
|
||||
|
||||
assert (
|
||||
"azure_ad_token" in litellm_params
|
||||
), "azure_ad_token not found in parameters"
|
||||
assert (
|
||||
litellm_params["azure_ad_token"] == "oidc/test-token"
|
||||
), "azure_ad_token is not correct"
|
||||
|
||||
# More detailed verification (optional)
|
||||
for call in azure_calls:
|
||||
assert "api_key" in call.kwargs, "api_key not found in parameters"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue