From 06b297a6e812bd6fb6da0426c1b01203eb4c61d3 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 21 Jun 2024 17:08:54 -0700 Subject: [PATCH] fix(router.py): fix set_client init to check if custom_llm_provider is azure not if in model name fixes issue where 'azure_ai/' was being init as azureopenai client --- litellm/router.py | 4 ++-- litellm/tests/test_router.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 284cb3203..d7e2aa12f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3090,7 +3090,7 @@ class Router: if not, add it - https://github.com/BerriAI/litellm/issues/2279 """ if ( - is_azure_ai_studio_model == True + is_azure_ai_studio_model is True and api_base is not None and isinstance(api_base, str) and not api_base.endswith("/v1/") @@ -3174,7 +3174,7 @@ class Router: organization = litellm.get_secret(organization_env_name) litellm_params["organization"] = organization - if "azure" in model_name: + if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": if api_base is None or not isinstance(api_base, str): filtered_litellm_params = { k: v diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index b84bc49d9..d2037dc59 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -70,6 +70,31 @@ def test_router_specific_model_via_id(): router.completion(model="1234", messages=[{"role": "user", "content": "Hey!"}]) +def test_router_azure_ai_client_init(): + + _deployment = { + "model_name": "meta-llama-3-70b", + "litellm_params": { + "model": "azure_ai/Meta-Llama-3-70B-instruct", + "api_base": "my-fake-route", + "api_key": "my-fake-key", + }, + "model_info": {"id": "1234"}, + } + router = Router(model_list=[_deployment]) + + _client = router._get_client( + deployment=_deployment, + client_type="async", + kwargs={"stream": False}, + ) + print(_client) + from openai import AsyncAzureOpenAI, AsyncOpenAI + + assert isinstance(_client, AsyncOpenAI) + assert not isinstance(_client, AsyncAzureOpenAI) + + def test_router_sensitive_keys(): try: router = Router(