From 594ca947c8846a0e9f3200663d558add714dc58c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 15 May 2024 20:16:11 -0700 Subject: [PATCH] fix(parallel_request_limiter.py): fix max parallel request limiter on retries --- .../proxy/hooks/parallel_request_limiter.py | 54 +++++++++++++++++-- litellm/proxy/proxy_server.py | 19 +++++++ .../tests/test_parallel_request_limiter.py | 31 +++++++++++ litellm/utils.py | 2 +- 4 files changed, 100 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 28e6d1853..26238b6c0 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -79,6 +79,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): max_parallel_requests = user_api_key_dict.max_parallel_requests if max_parallel_requests is None: max_parallel_requests = sys.maxsize + global_max_parallel_requests = data.get("metadata", {}).get( + "global_max_parallel_requests", None + ) tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) if tpm_limit is None: tpm_limit = sys.maxsize @@ -91,6 +94,24 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # Setup values # ------------ + if global_max_parallel_requests is not None: + # get value from cache + _key = "global_max_parallel_requests" + current_global_requests = await cache.async_get_cache( + key=_key, local_only=True + ) + # check if below limit + if current_global_requests is None: + current_global_requests = 1 + # if above -> raise error + if current_global_requests >= global_max_parallel_requests: + raise HTTPException( + status_code=429, detail="Max parallel request limit reached." + ) + # if below -> increment + else: + await cache.async_increment_cache(key=_key, value=1, local_only=True) + current_date = datetime.now().strftime("%Y-%m-%d") current_hour = datetime.now().strftime("%H") current_minute = datetime.now().strftime("%M") @@ -207,6 +228,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): try: self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING") + global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( + "global_max_parallel_requests", None + ) user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"] user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None @@ -222,6 +246,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # Setup values # ------------ + if global_max_parallel_requests is not None: + # get value from cache + _key = "global_max_parallel_requests" + # decrement + await self.user_api_key_cache.async_increment_cache( + key=_key, value=-1, local_only=True + ) + current_date = datetime.now().strftime("%Y-%m-%d") current_hour = datetime.now().strftime("%H") current_minute = datetime.now().strftime("%M") @@ -336,6 +368,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.print_verbose(f"Inside Max Parallel Request Failure Hook") + global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get( + "global_max_parallel_requests", None + ) user_api_key = ( kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None) ) @@ -347,17 +382,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): return ## decrement call count if call failed - if ( - hasattr(kwargs["exception"], "status_code") - and kwargs["exception"].status_code == 429 - and "Max parallel request limit reached" in str(kwargs["exception"]) - ): + if "Max parallel request limit reached" in str(kwargs["exception"]): pass # ignore failed calls due to max limit being reached else: # ------------ # Setup values # ------------ + if global_max_parallel_requests is not None: + # get value from cache + _key = "global_max_parallel_requests" + current_global_requests = ( + await self.user_api_key_cache.async_get_cache( + key=_key, local_only=True + ) + ) + # decrement + await self.user_api_key_cache.async_increment_cache( + key=_key, value=-1, local_only=True + ) + current_date = datetime.now().strftime("%Y-%m-%d") current_hour = datetime.now().strftime("%H") current_minute = datetime.now().strftime("%M") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7649f74d5..98be701b0 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2848,6 +2848,7 @@ class ProxyConfig: """ Pull from DB, read general settings value """ + global general_settings if db_general_settings is None: return _general_settings = dict(db_general_settings) @@ -3690,6 +3691,9 @@ async def chat_completion( data["metadata"]["user_api_key_alias"] = getattr( user_api_key_dict, "key_alias", None ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id data["metadata"]["user_api_key_team_id"] = getattr( @@ -3957,6 +3961,9 @@ async def completion( data["metadata"]["user_api_key_team_id"] = getattr( user_api_key_dict, "team_id", None ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) data["metadata"]["user_api_key_team_alias"] = getattr( user_api_key_dict, "team_alias", None ) @@ -4151,6 +4158,9 @@ async def embeddings( data["metadata"]["user_api_key_alias"] = getattr( user_api_key_dict, "key_alias", None ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["user_api_key_team_id"] = getattr( user_api_key_dict, "team_id", None @@ -4349,6 +4359,9 @@ async def image_generation( data["metadata"]["user_api_key_alias"] = getattr( user_api_key_dict, "key_alias", None ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id data["metadata"]["user_api_key_team_id"] = getattr( user_api_key_dict, "team_id", None @@ -4529,6 +4542,9 @@ async def audio_transcriptions( data["metadata"]["user_api_key_team_id"] = getattr( user_api_key_dict, "team_id", None ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) data["metadata"]["user_api_key_team_alias"] = getattr( user_api_key_dict, "team_alias", None ) @@ -4726,6 +4742,9 @@ async def moderations( "authorization", None ) # do not store the original `sk-..` api key in the db data["metadata"]["headers"] = _headers + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) data["metadata"]["user_api_key_alias"] = getattr( user_api_key_dict, "key_alias", None ) diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index d0a28926e..00da199d9 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -28,6 +28,37 @@ from datetime import datetime ## On Request failure +@pytest.mark.asyncio +async def test_global_max_parallel_requests(): + """ + Test if ParallelRequestHandler respects 'global_max_parallel_requests' + + data["metadata"]["global_max_parallel_requests"] + """ + global_max_parallel_requests = 0 + _api_key = "sk-12345" + _api_key = hash_token("sk-12345") + user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100) + local_cache = DualCache() + parallel_request_handler = MaxParallelRequestsHandler() + + for _ in range(3): + try: + await parallel_request_handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={ + "metadata": { + "global_max_parallel_requests": global_max_parallel_requests + } + }, + call_type="", + ) + pytest.fail("Expected call to fail") + except Exception as e: + pass + + @pytest.mark.asyncio async def test_pre_call_hook(): """ diff --git a/litellm/utils.py b/litellm/utils.py index 7df79f373..c84df360a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2579,7 +2579,7 @@ class Logging: response_obj=result, start_time=start_time, end_time=end_time, - ) + ) # type: ignore if callable(callback): # custom logger functions await customLogger.async_log_event( kwargs=self.model_call_details,