From 069d1f863d16e6e8f1ad539748d05ef3f8d8cb98 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 26 Apr 2024 17:05:07 -0700
Subject: [PATCH] fix(router.py): add `/v1/` if missing to base url, for
 openai-compatible api's

Fixes https://github.com/BerriAI/litellm/issues/2279
---
 litellm/proxy/_super_secret_config.yaml |  5 ++++
 litellm/router.py                       | 18 +++++++++++++
 litellm/tests/test_router.py            | 37 +++++++++++++++++++++++++
 3 files changed, 60 insertions(+)

diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 0ea72c85b1..89827df7c1 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -13,6 +13,11 @@ model_list:
 - litellm_params:
     model: gpt-4
   model_name: gpt-4
+- model_name: azure-mistral
+  litellm_params:
+    model: azure/mistral-large-latest
+    api_base: https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com
+    api_key: os.environ/AZURE_MISTRAL_API_KEY
 
 # litellm_settings: 
 #   cache: True
\ No newline at end of file
diff --git a/litellm/router.py b/litellm/router.py
index 8cb0f3ed2a..f84b2eab02 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1929,6 +1929,7 @@ class Router:
             )
             default_api_base = api_base
             default_api_key = api_key
+
         if (
             model_name in litellm.open_ai_chat_completion_models
             or custom_llm_provider in litellm.openai_compatible_providers
@@ -1964,6 +1965,23 @@ class Router:
                 api_base = litellm.get_secret(api_base_env_name)
                 litellm_params["api_base"] = api_base
 
+            ## AZURE AI STUDIO MISTRAL CHECK ##
+            # OpenAI-compatible serverless endpoints (e.g. Azure AI Studio) need
+            # the base url to end in "/v1/" — add it if missing.
+            # https://github.com/BerriAI/litellm/issues/2279
+            if (
+                custom_llm_provider == "openai"
+                and api_base is not None
+                and not api_base.endswith("/v1/")
+            ):
+                if api_base.endswith("/v1"):
+                    # already has "/v1", only the trailing slash is missing
+                    api_base += "/"
+                elif api_base.endswith("/"):
+                    api_base += "v1/"
+                else:
+                    api_base += "/v1/"
+
             api_version = litellm_params.get("api_version")
             if api_version and api_version.startswith("os.environ/"):
                 api_version_env_name = api_version.replace("os.environ/", "")
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index fd95083b7b..914e36da1a 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -65,6 +65,43 @@ def test_router_timeout_init(timeout, ssl_verify):
     )
 
 
+@pytest.mark.parametrize(
+    "mistral_api_base",
+    [
+        "os.environ/AZURE_MISTRAL_API_BASE",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1/",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/v1",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com/",
+        "https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com",
+    ],
+)
+def test_router_azure_ai_studio_init(mistral_api_base):
+    router = Router(
+        model_list=[
+            {
+                "model_name": "test-model",
+                "litellm_params": {
+                    "model": "azure/mistral-large-latest",
+                    "api_key": "os.environ/AZURE_MISTRAL_API_KEY",
+                    "api_base": mistral_api_base,
+                },
+                "model_info": {"id": 1234},
+            }
+        ]
+    )
+
+    model_client = router._get_client(
+        deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={}
+    )
+    url = getattr(model_client, "_base_url")
+    uri_reference = str(getattr(url, "_uri_reference"))
+
+    print(f"uri_reference: {uri_reference}")
+
+    assert "/v1/" in uri_reference
+    assert "/v1/v1/" not in uri_reference
+
+
 def test_exception_raising():
     # this tests if the router raises an exception when invalid params are set
     # in this test both deployments have bad keys - Keep this test. It validates if the router raises the most recent exception