diff --git a/litellm/router.py b/litellm/router.py index 629938158d..3558fa574e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5345,6 +5345,13 @@ class Router: client = self.cache.get_cache( key=cache_key, local_only=True, parent_otel_span=parent_otel_span ) + if client is None: + InitalizeOpenAISDKClient.set_max_parallel_requests_client( + litellm_router_instance=self, model=deployment + ) + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) return client elif client_type == "async": if kwargs.get("stream") is True: diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 4843b01e90..34c4a7534c 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -45,18 +45,11 @@ class InitalizeOpenAISDKClient: return True @staticmethod - def set_client( # noqa: PLR0915 + def set_max_parallel_requests_client( litellm_router_instance: LitellmRouter, model: dict ): - """ - - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 - """ - client_ttl = litellm_router_instance.client_ttl litellm_params = model.get("litellm_params", {}) - model_name = litellm_params.get("model") model_id = model["model_info"]["id"] - # ### IF RPM SET - initialize a semaphore ### rpm = litellm_params.get("rpm", None) tpm = litellm_params.get("tpm", None) max_parallel_requests = litellm_params.get("max_parallel_requests", None) @@ -75,6 +68,19 @@ class InitalizeOpenAISDKClient: local_only=True, ) + @staticmethod + def set_client( # noqa: PLR0915 + litellm_router_instance: LitellmRouter, model: dict + ): + """ + - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 + - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 + """ + client_ttl = litellm_router_instance.client_ttl + litellm_params = model.get("litellm_params", {}) + model_name = litellm_params.get("model") + model_id = model["model_info"]["id"] + #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## custom_llm_provider = litellm_params.get("custom_llm_provider") custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" @@ -185,7 +191,6 @@ class InitalizeOpenAISDKClient: organization_env_name = organization.replace("os.environ/", "") organization = get_secret_str(organization_env_name) litellm_params["organization"] = organization - else: _api_key = api_key # type: ignore if _api_key is not None and isinstance(_api_key, str): diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 5003499ba9..20a2f28c95 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -320,91 +320,6 @@ def test_router_order(): assert response._hidden_params["model_id"] == "1" -@pytest.mark.parametrize("num_retries", [None, 2]) -@pytest.mark.parametrize("max_retries", [None, 4]) -def test_router_num_retries_init(num_retries, max_retries): - """ - - test when num_retries set v/s not - - test client value when max retries set v/s not - """ - router = Router( - model_list=[ - { - "model_name": "gpt-3.5-turbo", # openai model name - "litellm_params": { # params for litellm completion/embedding call - "model": "azure/chatgpt-v-2", - "api_key": "bad-key", - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE"), - "max_retries": max_retries, - }, - "model_info": {"id": 12345}, - }, - ], - num_retries=num_retries, - ) - - if num_retries is not None: - assert router.num_retries == num_retries - else: - assert router.num_retries == openai.DEFAULT_MAX_RETRIES - - model_client = router._get_client( - {"model_info": {"id": 12345}}, client_type="async", kwargs={} - ) - - if max_retries is not None: - assert getattr(model_client, "max_retries") == max_retries - else: - assert getattr(model_client, "max_retries") == 0 - - -@pytest.mark.parametrize( - "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)] -) -@pytest.mark.parametrize("ssl_verify", [True, False]) -def test_router_timeout_init(timeout, ssl_verify): - """ - Allow user to pass httpx.Timeout - - related issue - https://github.com/BerriAI/litellm/issues/3162 - """ - litellm.ssl_verify = ssl_verify - - router = Router( - model_list=[ - { - "model_name": "test-model", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_base": os.getenv("AZURE_API_BASE"), - "api_version": os.getenv("AZURE_API_VERSION"), - "timeout": timeout, - }, - "model_info": {"id": 1234}, - } - ] - ) - - model_client = router._get_client( - deployment={"model_info": {"id": 1234}}, client_type="sync_client", kwargs={} - ) - - assert getattr(model_client, "timeout") == timeout - - print(f"vars model_client: {vars(model_client)}") - http_client = getattr(model_client, "_client") - print(f"http client: {vars(http_client)}, ssl_Verify={ssl_verify}") - if ssl_verify == False: - assert http_client._transport._pool._ssl_context.verify_mode.name == "CERT_NONE" - else: - assert ( - http_client._transport._pool._ssl_context.verify_mode.name - == "CERT_REQUIRED" - ) - - @pytest.mark.parametrize("sync_mode", [False, True]) @pytest.mark.asyncio async def test_router_retries(sync_mode):