diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 007004ca39..7bec1d5d66 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -5,6 +5,8 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
+from litellm import ModelResponse
+from datetime import datetime
 
 
 class MaxParallelRequestsHandler(CustomLogger):
@@ -35,16 +37,37 @@ class MaxParallelRequestsHandler(CustomLogger):
             return
 
         self.user_api_key_cache = cache  # save the api key cache for updating the value
 
+        # ------------
+        # Setup values
+        # ------------
+
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        current_hour = datetime.now().strftime("%H")
+        current_minute = datetime.now().strftime("%M")
+        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+        request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+
         # CHECK IF REQUEST ALLOWED
-        request_count_api_key = f"{api_key}_request_count"
-        current = cache.get_cache(key=request_count_api_key)
+        current = cache.get_cache(
+            key=request_count_api_key
+        )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         self.print_verbose(f"current: {current}")
         if current is None:
-            cache.set_cache(request_count_api_key, 1)
-        elif int(current) < max_parallel_requests:
+            new_val = {
+                "current_requests": 1,
+                "current_tpm": 0,
+                "current_rpm": 0,
+            }
+            cache.set_cache(request_count_api_key, new_val)
+        elif int(current["current_requests"]) < max_parallel_requests:
             # Increase count for this token
-            cache.set_cache(request_count_api_key, int(current) + 1)
+            new_val = {
+                "current_requests": current["current_requests"] + 1,
+                "current_tpm": current["current_tpm"],
+                "current_rpm": current["current_rpm"],
+            }
+            cache.set_cache(request_count_api_key, new_val)
         else:
             raise HTTPException(
                 status_code=429, detail="Max parallel request limit reached."
             )
@@ -60,12 +83,42 @@ class MaxParallelRequestsHandler(CustomLogger):
             if self.user_api_key_cache is None:
                 return
 
-            request_count_api_key = f"{user_api_key}_request_count"
-            # Decrease count for this token
-            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
-            new_val = current - 1
+            # ------------
+            # Setup values
+            # ------------
+
+            current_date = datetime.now().strftime("%Y-%m-%d")
+            current_hour = datetime.now().strftime("%H")
+            current_minute = datetime.now().strftime("%M")
+            precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+            total_tokens = 0
+
+            if isinstance(response_obj, ModelResponse):
+                total_tokens = response_obj.usage.total_tokens
+
+            request_count_api_key = f"{user_api_key}::{precise_minute}::request_count"
+
+            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
+                "current_requests": 1,
+                "current_tpm": total_tokens,
+                "current_rpm": 1,
+            }
+
+            # ------------
+            # Update usage
+            # ------------
+
+            new_val = {
+                "current_requests": current["current_requests"] - 1,
+                "current_tpm": current["current_tpm"] + total_tokens,
+                "current_rpm": current["current_rpm"] + 1,
+            }
+
             self.print_verbose(f"updated_value in success call: {new_val}")
-            self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+            self.user_api_key_cache.set_cache(
+                request_count_api_key, new_val, ttl=60
+            )  # store in cache for 1 min.
         except Exception as e:
             self.print_verbose(e)  # noqa
@@ -87,13 +140,40 @@ class MaxParallelRequestsHandler(CustomLogger):
             ):
                 pass  # ignore failed calls due to max limit being reached
             else:
-                request_count_api_key = f"{user_api_key}_request_count"
-                # Decrease count for this token
-                current = (
-                    self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+                # ------------
+                # Setup values
+                # ------------
+
+                current_date = datetime.now().strftime("%Y-%m-%d")
+                current_hour = datetime.now().strftime("%H")
+                current_minute = datetime.now().strftime("%M")
+                precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+                request_count_api_key = (
+                    f"{user_api_key}::{precise_minute}::request_count"
                 )
-                new_val = current - 1
+
+                # ------------
+                # Update usage
+                # ------------
+
+                current = self.user_api_key_cache.get_cache(
+                    key=request_count_api_key
+                ) or {
+                    "current_requests": 1,
+                    "current_tpm": 0,
+                    "current_rpm": 0,
+                }
+
+                new_val = {
+                    "current_requests": current["current_requests"] - 1,
+                    "current_tpm": current["current_tpm"],
+                    "current_rpm": current["current_rpm"],
+                }
+
                 self.print_verbose(f"updated_value in failure call: {new_val}")
-                self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+                self.user_api_key_cache.set_cache(
+                    request_count_api_key, new_val, ttl=60
+                )  # save in cache for up to 1 min.
         except Exception as e:
             self.print_verbose(f"An exception occurred - {str(e)}")  # noqa
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index 41c9d3c828..ff03e251da 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -19,6 +19,7 @@ from litellm.proxy.utils import ProxyLogging
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+from datetime import datetime
 
 ## On Request received
 ## On Request success
@@ -39,15 +40,19 @@ async def test_pre_call_hook():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
 
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
+
     print(
-        parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+        parallel_request_handler.user_api_key_cache.get_cache(key=request_count_api_key)
     )
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -66,10 +71,16 @@ async def test_success_call_hook():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
 
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
+
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -81,8 +92,8 @@ async def test_success_call_hook():
 
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 0
     )
@@ -101,10 +112,16 @@ async def test_failure_call_hook():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
 
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
+
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -119,8 +136,8 @@ async def test_failure_call_hook():
 
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 0
     )
@@ -175,10 +192,16 @@ async def test_normal_router_call():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
 
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
+
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -190,12 +213,13 @@ async def test_normal_router_call():
     )
     await asyncio.sleep(1)  # success is done in a separate thread
     print(f"response: {response}")
-    value = parallel_request_handler.user_api_key_cache.get_cache(
-        key=f"{_api_key}_request_count"
-    )
-    print(f"cache value: {value}")
-    assert value == 0
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=request_count_api_key
+        )["current_requests"]
+        == 0
+    )
 
 
 @pytest.mark.asyncio
@@ -240,10 +264,16 @@ async def test_streaming_router_call():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
 
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
+
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -257,12 +287,12 @@ async def test_streaming_router_call():
     async for chunk in response:
         continue
     await asyncio.sleep(1)  # success is done in a separate thread
-    value = parallel_request_handler.user_api_key_cache.get_cache(
-        key=f"{_api_key}_request_count"
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=request_count_api_key
+        )["current_requests"]
+        == 0
     )
-    print(f"cache value: {value}")
-
-    assert value == 0
 
 
 @pytest.mark.asyncio
@@ -307,10 +337,16 @@ async def test_bad_router_call():
         user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
     )
 
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
+
     assert (
         parallel_request_handler.user_api_key_cache.get_cache(
-            key=f"{_api_key}_request_count"
-        )
+            key=request_count_api_key
+        )["current_requests"]
         == 1
     )
@@ -324,9 +360,9 @@ async def test_bad_router_call():
         )
     except:
         pass
-    value = parallel_request_handler.user_api_key_cache.get_cache(
-        key=f"{_api_key}_request_count"
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=request_count_api_key
+        )["current_requests"]
+        == 0
     )
-    print(f"cache value: {value}")
-
-    assert value == 0
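
Note: the patch changes the limiter's cache entry from a bare integer keyed on `{api_key}_request_count` to a per-minute bucket keyed on `{api_key}::{YYYY-MM-DD-HH-MM}::request_count`, holding a dict of in-flight requests plus per-minute token and request counts, stored with a 60s TTL. A minimal standalone sketch of that scheme follows, assuming a dict-backed cache; `SimpleCache`, `pre_call`, and `on_success` are hypothetical stand-ins for illustration, not litellm APIs.

    # Sketch of the per-minute counter scheme introduced by this patch.
    from datetime import datetime


    class SimpleCache:
        """Hypothetical in-memory stand-in for litellm's DualCache.
        TTL is accepted but not enforced in this sketch."""

        def __init__(self):
            self._store = {}

        def get_cache(self, key):
            return self._store.get(key)

        def set_cache(self, key, value, ttl=None):
            self._store[key] = value


    def _counter_key(api_key: str) -> str:
        # Same key format as the patch: one counter bucket per key per minute.
        precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
        return f"{api_key}::{precise_minute}::request_count"


    def pre_call(cache: SimpleCache, api_key: str, max_parallel_requests: int):
        # Mirrors async_pre_call_hook: initialize or increment the bucket,
        # rejecting once current_requests hits the parallel-request cap.
        key = _counter_key(api_key)
        current = cache.get_cache(key)
        if current is None:
            cache.set_cache(key, {"current_requests": 1, "current_tpm": 0, "current_rpm": 0})
        elif current["current_requests"] < max_parallel_requests:
            cache.set_cache(key, {**current, "current_requests": current["current_requests"] + 1})
        else:
            raise RuntimeError("Max parallel request limit reached.")  # HTTP 429 in the proxy


    def on_success(cache: SimpleCache, api_key: str, total_tokens: int):
        # Mirrors the success handler: release the parallel-request slot and
        # add the response's tokens / one request to the minute's TPM/RPM tally.
        key = _counter_key(api_key)
        current = cache.get_cache(key) or {
            "current_requests": 1,
            "current_tpm": total_tokens,
            "current_rpm": 1,
        }
        cache.set_cache(
            key,
            {
                "current_requests": current["current_requests"] - 1,
                "current_tpm": current["current_tpm"] + total_tokens,
                "current_rpm": current["current_rpm"] + 1,
            },
            ttl=60,  # the patch keeps each bucket for ~1 minute
        )


    if __name__ == "__main__":
        cache = SimpleCache()
        pre_call(cache, "sk-test", max_parallel_requests=1)
        on_success(cache, "sk-test", total_tokens=42)
        print(cache.get_cache(_counter_key("sk-test")))
        # -> {'current_requests': 0, 'current_tpm': 42, 'current_rpm': 1}

One design consequence worth noting: bucketing the key by minute makes the counters self-expiring, so a request that crashes without ever decrementing its slot cannot permanently consume parallel-request capacity; its bucket stops being read a minute later and the 60s TTL reclaims it.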