diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 242c013d67..76f08892ef 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -222,7 +222,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             )
             # check if below limit
             if current_global_requests is None:
-                current_global_requests = 1
+                current_global_requests = 0
             # if above -> raise error
             if current_global_requests >= global_max_parallel_requests:
                 return self.raise_rate_limit_error(
diff --git a/tests/local_testing/test_parallel_request_limiter.py b/tests/local_testing/test_parallel_request_limiter.py
index 8b34e03454..e4b04d2fb0 100644
--- a/tests/local_testing/test_parallel_request_limiter.py
+++ b/tests/local_testing/test_parallel_request_limiter.py
@@ -67,6 +67,35 @@ async def test_global_max_parallel_requests():
     except Exception as e:
         print(e)
 
+    # Test: n requests (up to global_max_parallel_requests) must succeed
+    # and (n+1)th request must fail.
+    global_max_parallel_requests = 4
+    for _ in range(global_max_parallel_requests):
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={
+                "metadata": {
+                    "global_max_parallel_requests": global_max_parallel_requests
+                }
+            },
+            call_type="",
+        )
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={
+                "metadata": {
+                    "global_max_parallel_requests": global_max_parallel_requests
+                }
+            },
+            call_type="",
+        )
+        pytest.fail("Expected call to fail")
+    except Exception as e:
+        print(e)
+
 
 @pytest.mark.flaky(retries=6, delay=1)
 @pytest.mark.asyncio