refactor(batches/main.py): working refactored azure client init on batches

This commit is contained in:
Krrish Dholakia 2025-03-11 14:36:38 -07:00
parent 9855e46208
commit 1516240bab
3 changed files with 51 additions and 19 deletions

View file

@ -219,11 +219,12 @@ def test_select_azure_base_url_called(setup_mocks):
CallTypes.acompletion,
CallTypes.atext_completion,
CallTypes.aembedding,
# CallTypes.arerank,
CallTypes.atranscription,
CallTypes.aspeech,
CallTypes.aimage_generation,
# BATCHES ENDPOINTS
CallTypes.acreate_batch,
CallTypes.aretrieve_batch,
# ASSISTANT ENDPOINTS
],
)
@ -260,6 +261,12 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
"arerank": {"input": "Hello, how are you?"},
"atranscription": {"file": "path/to/file"},
"aspeech": {"input": "Hello, how are you?", "voice": "female"},
"acreate_batch": {
"completion_window": 10,
"endpoint": "https://test.openai.azure.com",
"input_file_id": "123",
},
"aretrieve_batch": {"batch_id": "123"},
}
# Get appropriate input for this call type
@ -270,6 +277,14 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
patch_target = (
"litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client"
)
elif call_type == CallTypes.arerank:
patch_target = (
"litellm.rerank_api.main.azure_rerank.initialize_azure_sdk_client"
)
elif call_type == CallTypes.acreate_batch or call_type == CallTypes.aretrieve_batch:
patch_target = (
"litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
)
# Mock the initialize_azure_sdk_client function
with patch(patch_target) as mock_init_azure: