refactor(azure.py): refactor to have client init work across all endpoints

This commit is contained in:
Krrish Dholakia 2025-03-11 17:27:24 -07:00
parent 1516240bab
commit 2f262ed9b4
10 changed files with 296 additions and 129 deletions

View file

@ -216,16 +216,18 @@ def test_select_azure_base_url_called(setup_mocks):
@pytest.mark.parametrize(
"call_type",
[
CallTypes.acompletion,
CallTypes.atext_completion,
CallTypes.aembedding,
CallTypes.atranscription,
CallTypes.aspeech,
CallTypes.aimage_generation,
# BATCHES ENDPOINTS
CallTypes.acreate_batch,
CallTypes.aretrieve_batch,
# ASSISTANT ENDPOINTS
call_type
for call_type in CallTypes.__members__.values()
if call_type.name.startswith("a")
and call_type.name
not in [
"amoderation",
"arerank",
"arealtime",
"anthropic_messages",
"add_message",
"arun_thread_stream",
]
],
)
@pytest.mark.asyncio
@ -267,6 +269,28 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
"input_file_id": "123",
},
"aretrieve_batch": {"batch_id": "123"},
"aget_assistants": {"custom_llm_provider": "azure"},
"acreate_assistants": {"custom_llm_provider": "azure"},
"adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"},
"acreate_thread": {"custom_llm_provider": "azure"},
"aget_thread": {"custom_llm_provider": "azure", "thread_id": "123"},
"a_add_message": {
"custom_llm_provider": "azure",
"thread_id": "123",
"role": "user",
"content": "Hello, how are you?",
},
"aget_messages": {"custom_llm_provider": "azure", "thread_id": "123"},
"arun_thread": {
"custom_llm_provider": "azure",
"assistant_id": "123",
"thread_id": "123",
},
"acreate_file": {
"custom_llm_provider": "azure",
"file": MagicMock(),
"purpose": "assistants",
},
}
# Get appropriate input for this call type
@ -285,12 +309,34 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
patch_target = (
"litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
)
elif (
call_type == CallTypes.aget_assistants
or call_type == CallTypes.acreate_assistants
or call_type == CallTypes.adelete_assistant
or call_type == CallTypes.acreate_thread
or call_type == CallTypes.aget_thread
or call_type == CallTypes.a_add_message
or call_type == CallTypes.aget_messages
or call_type == CallTypes.arun_thread
):
patch_target = (
"litellm.assistants.main.azure_assistants_api.initialize_azure_sdk_client"
)
elif call_type == CallTypes.acreate_file or call_type == CallTypes.afile_content:
patch_target = (
"litellm.files.main.azure_files_instance.initialize_azure_sdk_client"
)
# Mock the initialize_azure_sdk_client function
with patch(patch_target) as mock_init_azure:
# Also mock async_function_with_fallbacks to prevent actual API calls
# Call the appropriate router method
try:
get_attr = getattr(router, call_type.value, None)
if get_attr is None:
pytest.skip(
f"Skipping {call_type.value} because it is not supported on Router"
)
await getattr(router, call_type.value)(
model="gpt-3.5-turbo",
**input_kwarg,