From 9f6e90e17d12d25f1549d9f3992a9ac65b726047 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 20 Apr 2024 12:56:54 -0700 Subject: [PATCH] test(test_router_max_parallel_requests.py): more extensive testing for setting max parallel requests --- litellm/router.py | 2 +- .../test_router_max_parallel_requests.py | 66 ++++++++++++++++++- litellm/utils.py | 5 +- 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index a80dcf5ad4..d60767f3fa 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2562,7 +2562,7 @@ class Router: """ model_id = deployment["model_info"]["id"] if client_type == "max_parallel_requests": - cache_key = "{}_max_parallel_requests".format(model_id) + cache_key = "{}_max_parallel_requests_client".format(model_id) client = self.cache.get_cache(key=cache_key, local_only=True) return client elif client_type == "async": diff --git a/litellm/tests/test_router_max_parallel_requests.py b/litellm/tests/test_router_max_parallel_requests.py index 43c3694ff7..f9cac6aafb 100644 --- a/litellm/tests/test_router_max_parallel_requests.py +++ b/litellm/tests/test_router_max_parallel_requests.py @@ -7,6 +7,7 @@ import pytest sys.path.insert(0, os.path.abspath("../..")) import litellm from litellm.utils import calculate_max_parallel_requests +from typing import Optional """ - only rpm @@ -19,7 +20,7 @@ from litellm.utils import calculate_max_parallel_requests max_parallel_requests_values = [None, 10] -tpm_values = [None, 20] +tpm_values = [None, 20, 300000] rpm_values = [None, 30] default_max_parallel_requests = [None, 40] @@ -46,8 +47,69 @@ def test_scenario(max_parallel_requests, tpm, rpm, default_max_parallel_requests elif rpm is not None: assert rpm == calculated_max_parallel_requests elif tpm is not None: - assert int(tpm / 1000 / 6) == calculated_max_parallel_requests + calculated_rpm = int(tpm / 1000 / 6) + if calculated_rpm == 0: + calculated_rpm = 1 + print( + f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={calculated_max_parallel_requests}" + ) + assert calculated_rpm == calculated_max_parallel_requests elif default_max_parallel_requests is not None: assert calculated_max_parallel_requests == default_max_parallel_requests else: assert calculated_max_parallel_requests is None + + +@pytest.mark.parametrize( + "max_parallel_requests, tpm, rpm, default_max_parallel_requests", + [ + (mp, tp, rp, dmp) + for mp in max_parallel_requests_values + for tp in tpm_values + for rp in rpm_values + for dmp in default_max_parallel_requests + ], +) +def test_setting_mpr_limits_per_model( + max_parallel_requests, tpm, rpm, default_max_parallel_requests +): + deployment = { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "max_parallel_requests": max_parallel_requests, + "tpm": tpm, + "rpm": rpm, + }, + "model_info": {"id": "my-unique-id"}, + } + + router = litellm.Router( + model_list=[deployment], + default_max_parallel_requests=default_max_parallel_requests, + ) + + mpr_client: Optional[asyncio.Semaphore] = router._get_client( + deployment=deployment, + kwargs={}, + client_type="max_parallel_requests", + ) + + if max_parallel_requests is not None: + assert max_parallel_requests == mpr_client._value + elif rpm is not None: + assert rpm == mpr_client._value + elif tpm is not None: + calculated_rpm = int(tpm / 1000 / 6) + if calculated_rpm == 0: + calculated_rpm = 1 + print( + f"test calculated_rpm: {calculated_rpm}, calculated_max_parallel_requests={mpr_client._value}" + ) + assert calculated_rpm == mpr_client._value + elif default_max_parallel_requests is not None: + assert mpr_client._value == default_max_parallel_requests + else: + assert mpr_client is None + + # raise Exception("it worked!") diff --git a/litellm/utils.py b/litellm/utils.py index 566ef20996..0b4fb46607 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5429,7 +5429,10 @@ def calculate_max_parallel_requests( elif rpm is not None: return rpm elif tpm is not None: - return int(tpm / 1000 / 6) + calculated_rpm = int(tpm / 1000 / 6) + if calculated_rpm == 0: + calculated_rpm = 1 + return calculated_rpm elif default_max_parallel_requests is not None: return default_max_parallel_requests return None