diff --git a/litellm/router.py b/litellm/router.py
index a80dcf5ad4..d60767f3fa 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2562,7 +2562,7 @@ class Router:
         """
         model_id = deployment["model_info"]["id"]
         if client_type == "max_parallel_requests":
-            cache_key = "{}_max_parallel_requests".format(model_id)
+            cache_key = "{}_max_parallel_requests_client".format(model_id)
             client = self.cache.get_cache(key=cache_key, local_only=True)
             return client
         elif client_type == "async":
diff --git a/litellm/tests/test_router_max_parallel_requests.py b/litellm/tests/test_router_max_parallel_requests.py
index 43c3694ff7..f9cac6aafb 100644
--- a/litellm/tests/test_router_max_parallel_requests.py
+++ b/litellm/tests/test_router_max_parallel_requests.py
@@ -7,6 +7,7 @@ import pytest
 sys.path.insert(0, os.path.abspath("../.."))
 import litellm
 from litellm.utils import calculate_max_parallel_requests
+from typing import Optional
 
 """
 - only rpm
@@ -19,7 +20,7 @@ from litellm.utils import calculate_max_parallel_requests
 
 
 max_parallel_requests_values = [None, 10]
-tpm_values = [None, 20]
+tpm_values = [None, 20, 300000]
 rpm_values = [None, 30]
 default_max_parallel_requests = [None, 40]
 
@@ -46,8 +47,69 @@ def test_scenario(max_parallel_requests, tpm, rpm, default_max_parallel_requests
     elif rpm is not None:
         assert rpm == calculated_max_parallel_requests
     elif tpm is not None:
-        assert int(tpm / 1000 / 6) == calculated_max_parallel_requests
+        calculated_rpm = int(tpm / 1000 / 6)
+        if calculated_rpm == 0:
+            calculated_rpm = 1
+        print(
+            f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={calculated_max_parallel_requests}"
+        )
+        assert calculated_rpm == calculated_max_parallel_requests
     elif default_max_parallel_requests is not None:
         assert calculated_max_parallel_requests == default_max_parallel_requests
     else:
         assert calculated_max_parallel_requests is None
+
+
+@pytest.mark.parametrize(
+    "max_parallel_requests, tpm, rpm, default_max_parallel_requests",
+    [
+        (mp, tp, rp, dmp)
+        for mp in max_parallel_requests_values
+        for tp in tpm_values
+        for rp in rpm_values
+        for dmp in default_max_parallel_requests
+    ],
+)
+def test_setting_mpr_limits_per_model(
+    max_parallel_requests, tpm, rpm, default_max_parallel_requests
+):
+    deployment = {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "gpt-3.5-turbo",
+            "max_parallel_requests": max_parallel_requests,
+            "tpm": tpm,
+            "rpm": rpm,
+        },
+        "model_info": {"id": "my-unique-id"},
+    }
+
+    router = litellm.Router(
+        model_list=[deployment],
+        default_max_parallel_requests=default_max_parallel_requests,
+    )
+
+    mpr_client: Optional[asyncio.Semaphore] = router._get_client(
+        deployment=deployment,
+        kwargs={},
+        client_type="max_parallel_requests",
+    )
+
+    if max_parallel_requests is not None:
+        assert max_parallel_requests == mpr_client._value
+    elif rpm is not None:
+        assert rpm == mpr_client._value
+    elif tpm is not None:
+        calculated_rpm = int(tpm / 1000 / 6)
+        if calculated_rpm == 0:
+            calculated_rpm = 1
+        print(
+            f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={mpr_client._value}"
+        )
+        assert calculated_rpm == mpr_client._value
+    elif default_max_parallel_requests is not None:
+        assert mpr_client._value == default_max_parallel_requests
+    else:
+        assert mpr_client is None
+
+    # raise Exception("it worked!")
diff --git a/litellm/utils.py b/litellm/utils.py
index 566ef20996..0b4fb46607 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5429,7 +5429,10 @@ def calculate_max_parallel_requests(
     elif rpm is not None:
         return rpm
     elif tpm is not None:
-        return int(tpm / 1000 / 6)
+        calculated_rpm = int(tpm / 1000 / 6)
+        if calculated_rpm == 0:
+            calculated_rpm = 1
+        return calculated_rpm
     elif default_max_parallel_requests is not None:
         return default_max_parallel_requests
     return None
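
For reference, the behavior change in calculate_max_parallel_requests is that a tpm-derived
limit of 0 (any tpm below 6000) is now bumped up to 1, so a deployment with a very low tpm
still gets a usable semaphore slot. A minimal sanity-check sketch of the new fallback
(keyword names taken from the diff above; argument order is an assumption):

    from litellm.utils import calculate_max_parallel_requests

    # tpm fallback is int(tpm / 1000 / 6), now floored at 1 by this patch
    assert calculate_max_parallel_requests(
        max_parallel_requests=None, rpm=None, tpm=20, default_max_parallel_requests=None
    ) == 1  # int(20 / 1000 / 6) == 0 -> bumped to 1
    assert calculate_max_parallel_requests(
        max_parallel_requests=None, rpm=None, tpm=300000, default_max_parallel_requests=None
    ) == 50  # int(300000 / 1000 / 6) == 50, unchanged by the fix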