mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
refactor(azure.py): refactor to have client init work across all endpoints
This commit is contained in:
parent
d99d60a182
commit
cbc2e84044
10 changed files with 296 additions and 129 deletions
|
@ -216,16 +216,18 @@ def test_select_azure_base_url_called(setup_mocks):
|
|||
@pytest.mark.parametrize(
|
||||
"call_type",
|
||||
[
|
||||
CallTypes.acompletion,
|
||||
CallTypes.atext_completion,
|
||||
CallTypes.aembedding,
|
||||
CallTypes.atranscription,
|
||||
CallTypes.aspeech,
|
||||
CallTypes.aimage_generation,
|
||||
# BATCHES ENDPOINTS
|
||||
CallTypes.acreate_batch,
|
||||
CallTypes.aretrieve_batch,
|
||||
# ASSISTANT ENDPOINTS
|
||||
call_type
|
||||
for call_type in CallTypes.__members__.values()
|
||||
if call_type.name.startswith("a")
|
||||
and call_type.name
|
||||
not in [
|
||||
"amoderation",
|
||||
"arerank",
|
||||
"arealtime",
|
||||
"anthropic_messages",
|
||||
"add_message",
|
||||
"arun_thread_stream",
|
||||
]
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
@ -267,6 +269,28 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
|
|||
"input_file_id": "123",
|
||||
},
|
||||
"aretrieve_batch": {"batch_id": "123"},
|
||||
"aget_assistants": {"custom_llm_provider": "azure"},
|
||||
"acreate_assistants": {"custom_llm_provider": "azure"},
|
||||
"adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"},
|
||||
"acreate_thread": {"custom_llm_provider": "azure"},
|
||||
"aget_thread": {"custom_llm_provider": "azure", "thread_id": "123"},
|
||||
"a_add_message": {
|
||||
"custom_llm_provider": "azure",
|
||||
"thread_id": "123",
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?",
|
||||
},
|
||||
"aget_messages": {"custom_llm_provider": "azure", "thread_id": "123"},
|
||||
"arun_thread": {
|
||||
"custom_llm_provider": "azure",
|
||||
"assistant_id": "123",
|
||||
"thread_id": "123",
|
||||
},
|
||||
"acreate_file": {
|
||||
"custom_llm_provider": "azure",
|
||||
"file": MagicMock(),
|
||||
"purpose": "assistants",
|
||||
},
|
||||
}
|
||||
|
||||
# Get appropriate input for this call type
|
||||
|
@ -285,12 +309,34 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
|
|||
patch_target = (
|
||||
"litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
|
||||
)
|
||||
elif (
|
||||
call_type == CallTypes.aget_assistants
|
||||
or call_type == CallTypes.acreate_assistants
|
||||
or call_type == CallTypes.adelete_assistant
|
||||
or call_type == CallTypes.acreate_thread
|
||||
or call_type == CallTypes.aget_thread
|
||||
or call_type == CallTypes.a_add_message
|
||||
or call_type == CallTypes.aget_messages
|
||||
or call_type == CallTypes.arun_thread
|
||||
):
|
||||
patch_target = (
|
||||
"litellm.assistants.main.azure_assistants_api.initialize_azure_sdk_client"
|
||||
)
|
||||
elif call_type == CallTypes.acreate_file or call_type == CallTypes.afile_content:
|
||||
patch_target = (
|
||||
"litellm.files.main.azure_files_instance.initialize_azure_sdk_client"
|
||||
)
|
||||
|
||||
# Mock the initialize_azure_sdk_client function
|
||||
with patch(patch_target) as mock_init_azure:
|
||||
# Also mock async_function_with_fallbacks to prevent actual API calls
|
||||
# Call the appropriate router method
|
||||
try:
|
||||
get_attr = getattr(router, call_type.value, None)
|
||||
if get_attr is None:
|
||||
pytest.skip(
|
||||
f"Skipping {call_type.value} because it is not supported on Router"
|
||||
)
|
||||
await getattr(router, call_type.value)(
|
||||
model="gpt-3.5-turbo",
|
||||
**input_kwarg,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue