diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 7764cf4e6..864ad5260 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,7 +1,7 @@
 import sys
 import traceback
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
 
 from fastapi import HTTPException
 
@@ -53,6 +53,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         rpm_limit: int,
         request_count_api_key: str,
         rate_limit_type: Literal["user", "customer", "team"],
+        values_to_update_in_cache: List[Tuple[Any, Any]],
     ):
         current = await self.internal_usage_cache.async_get_cache(
             key=request_count_api_key,
@@ -69,11 +70,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": 0,
                 "current_rpm": 0,
             }
-            await self.internal_usage_cache.async_set_cache(
-                key=request_count_api_key,
-                value=new_val,
-                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-            )
+            values_to_update_in_cache.append((request_count_api_key, new_val))
         elif (
             int(current["current_requests"]) < max_parallel_requests
             and current["current_tpm"] < tpm_limit
@@ -85,11 +82,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": current["current_tpm"],
                 "current_rpm": current["current_rpm"],
             }
-            await self.internal_usage_cache.async_set_cache(
-                key=request_count_api_key,
-                value=new_val,
-                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-            )
+            values_to_update_in_cache.append((request_count_api_key, new_val))
         else:
             raise HTTPException(
                 status_code=429,
@@ -148,6 +141,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if rpm_limit is None:
             rpm_limit = sys.maxsize
 
+        values_to_update_in_cache: List[Tuple[Any, Any]] = (
+            []
+        )  # values that need to get updated in cache, will run a batch_set_cache after this function
+
         # ------------
         # Setup values
         # ------------
@@ -208,11 +205,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     "current_tpm": 0,
                     "current_rpm": 0,
                 }
-                await self.internal_usage_cache.async_set_cache(
-                    key=request_count_api_key,
-                    value=new_val,
-                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
             elif (
                 int(current["current_requests"]) < max_parallel_requests
                 and current["current_tpm"] < tpm_limit
@@ -224,11 +217,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     "current_tpm": current["current_tpm"],
                     "current_rpm": current["current_rpm"],
                 }
-                await self.internal_usage_cache.async_set_cache(
-                    key=request_count_api_key,
-                    value=new_val,
-                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
             else:
                 return self.raise_rate_limit_error(
                     additional_details=f"Hit limit for api_key: {api_key}. tpm_limit: {tpm_limit}, current_tpm {current['current_tpm']} , rpm_limit: {rpm_limit} current rpm {current['current_rpm']} "
@@ -268,11 +257,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     "current_tpm": 0,
                     "current_rpm": 0,
                 }
-                await self.internal_usage_cache.async_set_cache(
-                    key=request_count_api_key,
-                    value=new_val,
-                    litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-                )
+                values_to_update_in_cache.append((request_count_api_key, new_val))
             elif tpm_limit_for_model is not None or rpm_limit_for_model is not None:
                 # Increase count for this token
                 new_val = {
@@ -295,11 +280,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                         additional_details=f"Hit RPM limit for model: {_model} on api_key: {api_key}. rpm_limit: {rpm_limit_for_model}, current_rpm {current['current_rpm']} "
                     )
                 else:
-                    await self.internal_usage_cache.async_set_cache(
-                        key=request_count_api_key,
-                        value=new_val,
-                        litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-                    )
+                    values_to_update_in_cache.append((request_count_api_key, new_val))
 
         _remaining_tokens = None
         _remaining_requests = None
@@ -356,6 +337,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 tpm_limit=user_tpm_limit,
                 rpm_limit=user_rpm_limit,
                 rate_limit_type="user",
+                values_to_update_in_cache=values_to_update_in_cache,
             )
 
         # TEAM RATE LIMITS
@@ -384,6 +366,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 tpm_limit=team_tpm_limit,
                 rpm_limit=team_rpm_limit,
                 rate_limit_type="team",
+                values_to_update_in_cache=values_to_update_in_cache,
            )
 
         # End-User Rate Limits
@@ -417,8 +400,15 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 tpm_limit=end_user_tpm_limit,
                 rpm_limit=end_user_rpm_limit,
                 rate_limit_type="customer",
+                values_to_update_in_cache=values_to_update_in_cache,
             )
 
+        await self.internal_usage_cache.async_batch_set_cache(
+            cache_list=values_to_update_in_cache,
+            ttl=60,
+            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+        )
+
         return
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
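
For reviewers, a minimal sketch of the pattern this change adopts: each rate-limit check appends its (key, value) update to values_to_update_in_cache, and the hook flushes everything with a single async_batch_set_cache call at the end instead of awaiting one async_set_cache per check. The InMemoryUsageCache class and the example keys below are hypothetical stand-ins for illustration, not litellm's real DualCache.

# Illustrative sketch only, assuming a simplified cache with the same batch-set
# signature used in the diff (cache_list of (key, value) tuples, plus a ttl).
import asyncio
from typing import Any, List, Tuple


class InMemoryUsageCache:
    # Hypothetical stand-in for the proxy's internal usage cache.
    def __init__(self) -> None:
        self.store: dict = {}

    async def async_batch_set_cache(
        self, cache_list: List[Tuple[Any, Any]], ttl: int
    ) -> None:
        # One awaited call writes every collected entry (TTL handling omitted).
        for key, value in cache_list:
            self.store[key] = value


async def pre_call_hook(cache: InMemoryUsageCache) -> None:
    values_to_update_in_cache: List[Tuple[Any, Any]] = []
    # Each limit check appends its update instead of awaiting its own set call.
    values_to_update_in_cache.append(
        ("example_api_key::request_count", {"current_requests": 1})
    )
    values_to_update_in_cache.append(
        ("example_user_id::request_count", {"current_requests": 1})
    )
    # Single batched write at the end of the hook, mirroring the diff's final call.
    await cache.async_batch_set_cache(cache_list=values_to_update_in_cache, ttl=60)


asyncio.run(pre_call_hook(InMemoryUsageCache()))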