feat(parallel_request_limiter.py): move to doing an increment on check

reduces spillover (from 66 -> 2 at 10k+ requests in 1min.)
This commit is contained in:
Krrish Dholakia 2025-04-15 18:10:23 -07:00
parent 937a6e63ed
commit 897eb46320
3 changed files with 18 additions and 14 deletions

View file

@ -975,6 +975,7 @@ class RedisCache(BaseCache):
- increment_value: float
- ttl_seconds: int
"""
# don't waste a network request if there's nothing to increment
if len(increment_list) == 0:
return None

View file

@ -62,7 +62,7 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
self,
dual_cache=internal_usage_cache.dual_cache,
should_batch_redis_writes=True,
default_sync_interval=0.1,
default_sync_interval=0.01,
)
def print_verbose(self, print_statement):
@ -102,19 +102,13 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
precise_minute: str,
tpm_limit: int,
rpm_limit: int,
current_rpm: int,
current_tpm: int,
current_requests: int,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
):
verbose_proxy_logger.info(
f"Current Usage of {rate_limit_type} in this minute: {current_requests}, {current_tpm}, {current_rpm}"
f"Current Usage of {rate_limit_type} in this minute: {current_tpm}"
)
if (
current_requests >= max_parallel_requests
or current_tpm >= tpm_limit
or current_rpm >= rpm_limit
):
if current_tpm >= tpm_limit:
raise self.raise_rate_limit_error(
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
)
@ -131,11 +125,16 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
)
increment_list.append((key, 1))
await self._increment_value_list_in_current_window(
results = await self._increment_value_list_in_current_window(
increment_list=increment_list,
ttl=60,
)
if results[0] >= max_parallel_requests or results[1] >= rpm_limit:
raise self.raise_rate_limit_error(
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
)
async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
@ -242,6 +241,7 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
rpm_api_key,
tpm_api_key,
]
results = await self.internal_usage_cache.async_batch_get_cache(
keys=keys,
parent_otel_span=parent_otel_span,
@ -385,9 +385,7 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
precise_minute=precise_minute,
tpm_limit=tpm_limit,
rpm_limit=rpm_limit,
current_rpm=cache_objects["rpm_api_key"] or 0,
current_tpm=cache_objects["tpm_api_key"] or 0,
current_requests=cache_objects["request_count_api_key"] or 0,
rate_limit_type="key",
)

View file

@ -55,12 +55,17 @@ class BaseRoutingStrategy(ABC):
async def _increment_value_list_in_current_window(
self, increment_list: List[Tuple[str, int]], ttl: int
):
) -> List[float]:
"""
Increment a list of values in the current window
"""
results = []
for key, value in increment_list:
await self._increment_value_in_current_window(key=key, value=value, ttl=ttl)
result = await self._increment_value_in_current_window(
key=key, value=value, ttl=ttl
)
results.append(result)
return results
async def _increment_value_in_current_window(
self, key: str, value: Union[int, float], ttl: int