refactor: move to new check key in limits logic

uses Redis increment-based cache logic for the per-minute counters

ensures TPM/RPM enforcement stays consistent across proxy instances
Krrish Dholakia committed 2025-04-15 17:43:31 -07:00
parent 5266c8adf7 · commit 937a6e63ed
2 changed files with 131 additions and 16 deletions
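The commit message refers to Redis increment cache logic for cross-instance consistency. A minimal sketch of that pattern, not the code in this commit: per-minute counters live under plain cache keys shaped like the {api_key}::{precise_minute}::{group} format introduced in the diff below, every proxy instance increments them server-side, and the keys expire with the window. The helper name and the redis-py pipeline usage here are assumptions for illustration.

import redis.asyncio as redis

# Sketch only: atomic per-minute counters shared by all proxy instances.
# Key shape mirrors the usage keys introduced in this commit; names are illustrative.
async def bump_minute_counters(
    r: redis.Redis, api_key: str, precise_minute: str, tokens: int
) -> None:
    pipe = r.pipeline(transaction=True)
    for group, amount in (("request_count", 1), ("rpm", 1), ("tpm", tokens)):
        key = f"{api_key}::{precise_minute}::{group}"
        pipe.incrby(key, amount)   # server-side increment, no read-modify-write race
        pipe.expire(key, 60)       # counters only matter within the current minute window
    await pipe.execute()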


@@ -1,7 +1,17 @@
import asyncio
import sys
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
cast,
)
from fastapi import HTTPException
from pydantic import BaseModel
@@ -32,11 +42,16 @@ else:
class CacheObject(TypedDict):
current_global_requests: Optional[dict]
request_count_api_key: Optional[dict]
request_count_api_key: Optional[int]
request_count_api_key_model: Optional[dict]
request_count_user_id: Optional[dict]
request_count_team_id: Optional[dict]
request_count_end_user_id: Optional[dict]
rpm_api_key: Optional[int]
tpm_api_key: Optional[int]
RateLimitGroups = Literal["request_count", "tpm", "rpm"]
class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
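For orientation, a populated instance of the CacheObject TypedDict above might look like the following under the new schema (all values are invented): request_count_api_key is now a plain integer, and the new rpm_api_key / tpm_api_key fields carry the per-minute counters read back from the cache.

# Hypothetical CacheObject instance; all values are illustrative.
cache_objects: CacheObject = {
    "current_global_requests": None,
    "request_count_api_key": 3,          # parallel requests in flight for this key
    "request_count_api_key_model": None,
    "request_count_user_id": None,
    "request_count_team_id": None,
    "request_count_end_user_id": None,
    "rpm_api_key": 12,                   # requests counted for this key in the current minute
    "tpm_api_key": 4500,                 # tokens counted for this key in the current minute
}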
@@ -58,6 +73,69 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
except Exception:
pass
def _get_current_usage_key(
self,
user_api_key_dict: UserAPIKeyAuth,
precise_minute: str,
data: dict,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
group: RateLimitGroups,
) -> str:
if rate_limit_type == "key":
return f"{user_api_key_dict.api_key}::{precise_minute}::{group}"
elif rate_limit_type == "model_per_key":
return f"{user_api_key_dict.api_key}::{data.get('model')}::{precise_minute}::{group}"
elif rate_limit_type == "user":
return f"{user_api_key_dict.user_id}::{precise_minute}::{group}"
elif rate_limit_type == "customer":
return f"{user_api_key_dict.end_user_id}::{precise_minute}::{group}"
elif rate_limit_type == "team":
return f"{user_api_key_dict.team_id}::{precise_minute}::{group}"
else:
raise ValueError(f"Invalid rate limit type: {rate_limit_type}")
async def check_key_in_limits_v2(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
max_parallel_requests: int,
precise_minute: str,
tpm_limit: int,
rpm_limit: int,
current_rpm: int,
current_tpm: int,
current_requests: int,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
):
verbose_proxy_logger.info(
f"Current Usage of {rate_limit_type} in this minute: {current_requests}, {current_tpm}, {current_rpm}"
)
if (
current_requests >= max_parallel_requests
or current_tpm >= tpm_limit
or current_rpm >= rpm_limit
):
raise self.raise_rate_limit_error(
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
)
## INCREMENT CURRENT USAGE
increment_list: List[Tuple[str, int]] = []
for group in ["request_count", "rpm"]:
key = self._get_current_usage_key(
user_api_key_dict=user_api_key_dict,
precise_minute=precise_minute,
data=data,
rate_limit_type=rate_limit_type,
group=cast(RateLimitGroups, group),
)
increment_list.append((key, 1))
await self._increment_value_list_in_current_window(
increment_list=increment_list,
ttl=60,
)
async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
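To make the new key scheme concrete, here are a couple of key strings _get_current_usage_key above would build for a single minute window; the API key, model name, and precise_minute format are invented for illustration.

# Illustrative key strings only; identifiers and the minute format are assumptions.
api_key = "sk-1234"
precise_minute = "2025-04-15-17-43"
print(f"{api_key}::{precise_minute}::rpm")           # rate_limit_type="key", group="rpm"
print(f"{api_key}::gpt-4o::{precise_minute}::tpm")   # rate_limit_type="model_per_key", group="tpm"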
@@ -150,6 +228,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id: Optional[str],
request_count_team_id: Optional[str],
request_count_end_user_id: Optional[str],
rpm_api_key: Optional[str],
tpm_api_key: Optional[str],
parent_otel_span: Optional[Span] = None,
) -> CacheObject:
keys = [
@@ -159,6 +239,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id,
request_count_team_id,
request_count_end_user_id,
rpm_api_key,
tpm_api_key,
]
results = await self.internal_usage_cache.async_batch_get_cache(
keys=keys,
@@ -173,6 +255,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id=None,
request_count_team_id=None,
request_count_end_user_id=None,
rpm_api_key=None,
tpm_api_key=None,
)
return CacheObject(
@@ -182,6 +266,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id=results[3],
request_count_team_id=results[4],
request_count_end_user_id=results[5],
rpm_api_key=results[6],
tpm_api_key=results[7],
)
async def async_pre_call_hook( # noqa: PLR0915
@@ -261,6 +347,12 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
if api_key is not None
else None
),
rpm_api_key=(
f"{api_key}::{precise_minute}::rpm" if api_key is not None else None
),
tpm_api_key=(
f"{api_key}::{precise_minute}::tpm" if api_key is not None else None
),
request_count_api_key_model=(
f"{api_key}::{_model}::{precise_minute}::request_count"
if api_key is not None and _model is not None
@@ -286,20 +378,33 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
if api_key is not None:
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
await self.check_key_in_limits(
await self.check_key_in_limits_v2(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=max_parallel_requests,
current=cache_objects["request_count_api_key"],
request_count_api_key=request_count_api_key,
precise_minute=precise_minute,
tpm_limit=tpm_limit,
rpm_limit=rpm_limit,
current_rpm=cache_objects["rpm_api_key"] or 0,
current_tpm=cache_objects["tpm_api_key"] or 0,
current_requests=cache_objects["request_count_api_key"] or 0,
rate_limit_type="key",
values_to_update_in_cache=values_to_update_in_cache,
)
# await self.check_key_in_limits(
# user_api_key_dict=user_api_key_dict,
# cache=cache,
# data=data,
# call_type=call_type,
# max_parallel_requests=max_parallel_requests,
# current=cache_objects["request_count_api_key"],
# request_count_api_key=request_count_api_key,
# tpm_limit=tpm_limit,
# rpm_limit=rpm_limit,
# rate_limit_type="key",
# values_to_update_in_cache=values_to_update_in_cache,
# )
# Check if request under RPM/TPM per model for a given API Key
if (
get_key_model_tpm_limit(user_api_key_dict) is not None
@@ -441,13 +546,13 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
values_to_update_in_cache=values_to_update_in_cache,
)
asyncio.create_task(
self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # don't block execution for cache updates
)
# asyncio.create_task(
# self.internal_usage_cache.async_batch_set_cache(
# cache_list=values_to_update_in_cache,
# ttl=60,
# litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
# ) # don't block execution for cache updates
# )
return
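One plausible reason for replacing the deferred batch set (commented out above) with immediate server-side increments: a read-then-write counter update can lose increments when two proxy instances race on the same key, while an atomic increment cannot. A toy, self-contained illustration of that race, not proxy code:

import asyncio

counter = {"rpm": 9}   # stand-in for a shared cached counter

async def read_modify_write() -> None:
    current = counter["rpm"]      # both workers read 9
    await asyncio.sleep(0)        # yield so the other worker also reads the stale value
    counter["rpm"] = current + 1  # both write 10; one increment is lost

async def main() -> None:
    await asyncio.gather(read_modify_write(), read_modify_write())
    print(counter["rpm"])         # prints 10, not the expected 11

asyncio.run(main())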
@@ -810,6 +915,7 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
"""
Retrieve the key's remaining rate limits.
"""
return
api_key = user_api_key_dict.api_key
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")


@@ -5,7 +5,7 @@ Base class across routing strategies to abstract commmon functions like batch in
import asyncio
import threading
from abc import ABC
from typing import List, Optional, Set, Union
from typing import List, Optional, Set, Tuple, Union
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
@@ -53,6 +53,15 @@ class BaseRoutingStrategy(ABC):
except asyncio.CancelledError:
pass
async def _increment_value_list_in_current_window(
self, increment_list: List[Tuple[str, int]], ttl: int
):
"""
Increment a list of values in the current window
"""
for key, value in increment_list:
await self._increment_value_in_current_window(key=key, value=value, ttl=ttl)
async def _increment_value_in_current_window(
self, key: str, value: Union[int, float], ttl: int
):