From 5c1a662caa7b820907110dbb0a4e93b3a7ed529d Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 5 Apr 2024 14:35:31 -0700
Subject: [PATCH] proxy - add azure/command r

---
 litellm/router.py                 |  5 ++++
 litellm/tests/test_router_init.py | 43 +++++++++++++++++++++++++++++++
 litellm/utils.py                  | 20 +++++++++-----
 3 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index a34d90e8aa..bf3ae66755 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1767,6 +1767,11 @@ class Router:
             or "ft:gpt-3.5-turbo" in model_name
             or model_name in litellm.open_ai_embedding_models
         ):
+            if custom_llm_provider == "azure":
+                if litellm.utils._is_non_openai_azure_model(model_name):
+                    custom_llm_provider = "openai"
+                    # remove azure prefix from model_name
+                    model_name = model_name.replace("azure/", "")
             # glorified / complicated reading of configs
             # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
             # we do this here because we init clients for Azure, OpenAI and we need to set the right key
diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py
index 5fa1420538..9953a8d3e8 100644
--- a/litellm/tests/test_router_init.py
+++ b/litellm/tests/test_router_init.py
@@ -440,3 +440,46 @@ def test_openai_with_organization():
 
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_init_clients_azure_command_r_plus():
+    # This tests that the router uses the OpenAI client for Azure/Command-R+
+    # For azure/command-r-plus we need to use openai.OpenAI because of how the Azure provider requires requests being sent
+    litellm.set_verbose = True
+    import logging
+    from litellm._logging import verbose_router_logger
+
+    verbose_router_logger.setLevel(logging.DEBUG)
+    try:
+        print("testing init 4 clients with diff timeouts")
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/command-r-plus",
+                    "api_key": os.getenv("AZURE_COHERE_API_KEY"),
+                    "api_base": os.getenv("AZURE_COHERE_API_BASE"),
+                    "timeout": 0.01,
+                    "stream_timeout": 0.000_001,
+                    "max_retries": 7,
+                },
+            },
+        ]
+        router = Router(model_list=model_list, set_verbose=True)
+        for elem in router.model_list:
+            model_id = elem["model_info"]["id"]
+            async_client = router.cache.get_cache(f"{model_id}_async_client")
+            stream_async_client = router.cache.get_cache(
+                f"{model_id}_stream_async_client"
+            )
+            # Assert the Async Clients used are OpenAI clients and not Azure
+            # For using Azure/Command-R-Plus and Azure/Mistral the clients NEED to be OpenAI clients
+            # this is weirdness introduced on Azure's side
+
+            assert "openai.AsyncOpenAI" in str(async_client)
+            assert "openai.AsyncOpenAI" in str(stream_async_client)
+        print("PASSED !")
+
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 6c0521265b..712d70eeeb 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5563,6 +5563,19 @@ def get_formatted_prompt(
     return prompt
 
 
+def _is_non_openai_azure_model(model: str) -> bool:
+    try:
+        model_name = model.split("/", 1)[1]
+        if (
+            model_name in litellm.cohere_chat_models
+            or f"mistral/{model_name}" in litellm.mistral_chat_models
+        ):
+            return True
+    except:
+        return False
+    return False
+
+
 def get_llm_provider(
     model: str,
     custom_llm_provider: Optional[str] = None,
@@ -5576,13 +5589,8 @@ def get_llm_provider(
         # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
         # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
         if model.split("/", 1)[0] == "azure":
-            model_name = model.split("/", 1)[1]
-            if (
-                model_name in litellm.cohere_chat_models
-                or f"mistral/{model_name}" in litellm.mistral_chat_models
-            ):
+            if _is_non_openai_azure_model(model):
                 custom_llm_provider = "openai"
-                model = model_name
             return model, custom_llm_provider, dynamic_api_key, api_base
 
         if custom_llm_provider:
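Usage note (illustrative sketch, not part of the patch): with this change, a proxy/router entry pointing at azure/command-r-plus should be served through an OpenAI-compatible client under the hood. The snippet below assumes the standard litellm Router API and reuses the AZURE_COHERE_API_KEY / AZURE_COHERE_API_BASE environment variables from the test above; the "command-r-plus" model alias is made up for the example.

    import os

    from litellm import Router

    # Router config enabled by this patch: azure/command-r-plus is detected as a
    # non-OpenAI Azure (AI Studio) model and its clients are openai.AsyncOpenAI.
    router = Router(
        model_list=[
            {
                "model_name": "command-r-plus",  # illustrative alias
                "litellm_params": {
                    "model": "azure/command-r-plus",
                    "api_key": os.getenv("AZURE_COHERE_API_KEY"),
                    "api_base": os.getenv("AZURE_COHERE_API_BASE"),
                },
            }
        ]
    )

    response = router.completion(
        model="command-r-plus",
        messages=[{"role": "user", "content": "Hello"}],
    )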