From c1ec1a3ed6d51e8d4ab98f6986cd64528b329d35 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 11 Mar 2025 17:52:05 -0700 Subject: [PATCH 1/2] test: remove redundant tests --- tests/local_testing/test_router.py | 85 ------------------------------ 1 file changed, 85 deletions(-) diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 5003499ba9..20a2f28c95 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -320,91 +320,6 @@ def test_router_order(): assert response._hidden_params["model_id"] == "1" -@pytest.mark.parametrize("num_retries", [None, 2]) -@pytest.mark.parametrize("max_retries", [None, 4]) -def test_router_num_retries_init(num_retries, max_retries): - """ - - test when num_retries set v/s not - - test client value when max retries set v/s not - """ - router = Router( - model_list=[ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE"), - "max_retries": max_retries, - }, - "model_info": {"id": 12345}, - }, - ], - num_retries=num_retries, - ) - - if num_retries is not None: - assert router.num_retries == num_retries - else: - assert router.num_retries == openai.DEFAULT_MAX_RETRIES - - model_client = router._get_client( - {"model_info": {"id": 12345}}, client_type="async", kwargs={} - ) - - if max_retries is not None: - assert getattr(model_client, "max_retries") == max_retries - else: - assert getattr(model_client, "max_retries") == 0 - - -@pytest.mark.parametrize( - "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)] -) -@pytest.mark.parametrize("ssl_verify", [True, False]) -def test_router_timeout_init(timeout, ssl_verify): - """ - Allow user to pass httpx.Timeout - - related issue - https://github.com/BerriAI/litellm/issues/3162 - """ - litellm.ssl_verify = ssl_verify - - router = Router( - model_list=[ - { - "model_name": "test-model", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - "timeout": timeout, - }, - "model_info": {"id": 1234}, - } - ] - ) - - model_client = router._get_client( - deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={} - ) - - assert getattr(model_client, "timeout") == timeout - - print(f"vars model_client: {vars(model_client)}") - http_client = getattr(model_client, "_client") - print(f"http client: {vars(http_client)}, ssl_Verify={ssl_verify}") - if ssl_verify == False: - assert http_client._transport._pool._ssl_context.verify_mode.name == "CERT_NONE" - else: - assert ( - http_client._transport._pool._ssl_context.verify_mode.name - == "CERT_REQUIRED" - ) - - @pytest.mark.parametrize("sync_mode", [False, True]) @pytest.mark.asyncio async def test_router_retries(sync_mode): From 507504906e2a6b492707b658e5155a2a6f472bc6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 11 Mar 2025 18:25:48 -0700 Subject: [PATCH 2/2] fix: fix max parallel requests client --- litellm/router.py | 7 ++++++ .../client_initalization_utils.py | 25 ++++++++++++------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index f573bf65a6..21d1dc64c6 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5346,6 +5346,13 @@ class Router: client = self.cache.get_cache( key=cache_key, local_only=True, parent_otel_span=parent_otel_span ) + if client is None: + InitalizeOpenAISDKClient.set_max_parallel_requests_client( + litellm_router_instance=self, model=deployment + ) + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) return client elif client_type == "async": if kwargs.get("stream") is True: diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 7956d8c72e..d896dbb88c 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -54,18 +54,11 @@ class InitalizeOpenAISDKClient: return True @staticmethod - def set_client( # noqa: PLR0915 + def set_max_parallel_requests_client( litellm_router_instance: LitellmRouter, model: dict ): - """ - - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 - """ - client_ttl = litellm_router_instance.client_ttl litellm_params = model.get("litellm_params", {}) - model_name = litellm_params.get("model") model_id = model["model_info"]["id"] - # ### IF RPM SET - initialize a semaphore ### rpm = litellm_params.get("rpm", None) tpm = litellm_params.get("tpm", None) max_parallel_requests = litellm_params.get("max_parallel_requests", None) @@ -84,6 +77,19 @@ class InitalizeOpenAISDKClient: local_only=True, ) + @staticmethod + def set_client( # noqa: PLR0915 + litellm_router_instance: LitellmRouter, model: dict + ): + """ + - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 + - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 + """ + client_ttl = litellm_router_instance.client_ttl + litellm_params = model.get("litellm_params", {}) + model_name = litellm_params.get("model") + model_id = model["model_info"]["id"] + #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## custom_llm_provider = litellm_params.get("custom_llm_provider") custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" @@ -233,7 +239,8 @@ class InitalizeOpenAISDKClient: if azure_ad_token.startswith("oidc/"): azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) elif ( - not api_key and azure_ad_token_provider is None + not api_key + and azure_ad_token_provider is None and litellm.enable_azure_ad_token_refresh is True ): try: