refactor: move to new check_key_in_limits_v2 logic

uses Redis increment cache logic

ensures tpm/rpm limits are enforced consistently across instances
Krrish Dholakia 2025-04-15 17:43:31 -07:00
parent 5266c8adf7
commit 937a6e63ed
2 changed files with 131 additions and 16 deletions
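A minimal sketch of the pattern this commit moves to: keep one Redis counter per (entity, minute, group) and bump it with an atomic INCR, so every proxy instance reads and updates the same totals. The helper name, key layout, and minute format below are illustrative assumptions, not LiteLLM's actual API:

import asyncio
from datetime import datetime

import redis.asyncio as redis


async def allow_request(r: redis.Redis, api_key: str, rpm_limit: int) -> bool:
    # One counter per key per minute, e.g. "sk-1234::2025-04-15-17-43::rpm"
    # (the minute format here is illustrative).
    precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
    counter_key = f"{api_key}::{precise_minute}::rpm"

    current = int(await r.get(counter_key) or 0)
    if current >= rpm_limit:
        return False  # over the per-minute limit; reject before incrementing

    # INCRBY is atomic server-side, so concurrent proxy instances cannot
    # lose updates via read-modify-write races; the TTL expires the window.
    pipe = r.pipeline()
    pipe.incrby(counter_key, 1)
    pipe.expire(counter_key, 60)
    await pipe.execute()
    return True


async def main() -> None:
    r = redis.Redis()
    print(await allow_request(r, "sk-1234", rpm_limit=60))


if __name__ == "__main__":
    asyncio.run(main())

Because the increment happens atomically in Redis rather than in per-instance memory, two instances admitting requests in the same minute always see each other's counts.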


@@ -1,7 +1,17 @@
import asyncio
import sys
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
cast,
)
from fastapi import HTTPException
from pydantic import BaseModel
@@ -32,11 +42,16 @@ else:
class CacheObject(TypedDict):
current_global_requests: Optional[dict]
request_count_api_key: Optional[dict]
request_count_api_key: Optional[int]
request_count_api_key_model: Optional[dict]
request_count_user_id: Optional[dict]
request_count_team_id: Optional[dict]
request_count_end_user_id: Optional[dict]
rpm_api_key: Optional[int]
tpm_api_key: Optional[int]
RateLimitGroups = Literal["request_count", "tpm", "rpm"]
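# "request_count" tracks in-flight parallel requests; "rpm" and "tpm" track
# requests and tokens consumed in the current minute window.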
class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
@@ -58,6 +73,69 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
except Exception:
pass
def _get_current_usage_key(
self,
user_api_key_dict: UserAPIKeyAuth,
precise_minute: str,
data: dict,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
group: RateLimitGroups,
) -> str:
if rate_limit_type == "key":
return f"{user_api_key_dict.api_key}::{precise_minute}::{group}"
elif rate_limit_type == "model_per_key":
return f"{user_api_key_dict.api_key}::{data.get('model')}::{precise_minute}::{group}"
elif rate_limit_type == "user":
return f"{user_api_key_dict.user_id}::{precise_minute}::{group}"
elif rate_limit_type == "customer":
return f"{user_api_key_dict.end_user_id}::{precise_minute}::{group}"
elif rate_limit_type == "team":
return f"{user_api_key_dict.team_id}::{precise_minute}::{group}"
else:
raise ValueError(f"Invalid rate limit type: {rate_limit_type}")
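# Example keys produced above (minute format illustrative):
#   ("key", "rpm")           -> "sk-1234::2025-04-15-17-43::rpm"
#   ("model_per_key", "tpm") -> "sk-1234::gpt-4::2025-04-15-17-43::tpm"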
async def check_key_in_limits_v2(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
max_parallel_requests: int,
precise_minute: str,
tpm_limit: int,
rpm_limit: int,
current_rpm: int,
current_tpm: int,
current_requests: int,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
):
verbose_proxy_logger.info(
f"Current Usage of {rate_limit_type} in this minute: {current_requests}, {current_tpm}, {current_rpm}"
)
if (
current_requests >= max_parallel_requests
or current_tpm >= tpm_limit
or current_rpm >= rpm_limit
):
raise self.raise_rate_limit_error(
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
)
## INCREMENT CURRENT USAGE
increment_list: List[Tuple[str, int]] = []
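# Only request_count and rpm can be bumped before the call; tpm presumably
# gets incremented in the post-call hook, once token usage is known.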
for group in ["request_count", "rpm"]:
key = self._get_current_usage_key(
user_api_key_dict=user_api_key_dict,
precise_minute=precise_minute,
data=data,
rate_limit_type=rate_limit_type,
group=cast(RateLimitGroups, group),
)
increment_list.append((key, 1))
await self._increment_value_list_in_current_window(
increment_list=increment_list,
ttl=60,
)
async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
@@ -150,6 +228,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id: Optional[str],
request_count_team_id: Optional[str],
request_count_end_user_id: Optional[str],
rpm_api_key: Optional[str],
tpm_api_key: Optional[str],
parent_otel_span: Optional[Span] = None,
) -> CacheObject:
keys = [
@@ -159,6 +239,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id,
request_count_team_id,
request_count_end_user_id,
rpm_api_key,
tpm_api_key,
]
results = await self.internal_usage_cache.async_batch_get_cache(
keys=keys,
@@ -173,6 +255,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id=None,
request_count_team_id=None,
request_count_end_user_id=None,
rpm_api_key=None,
tpm_api_key=None,
)
return CacheObject(
@@ -182,6 +266,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
request_count_user_id=results[3],
request_count_team_id=results[4],
request_count_end_user_id=results[5],
rpm_api_key=results[6],
tpm_api_key=results[7],
)
async def async_pre_call_hook( # noqa: PLR0915
@@ -261,6 +347,12 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
if api_key is not None
else None
),
rpm_api_key=(
f"{api_key}::{precise_minute}::rpm" if api_key is not None else None
),
tpm_api_key=(
f"{api_key}::{precise_minute}::tpm" if api_key is not None else None
),
request_count_api_key_model=(
f"{api_key}::{_model}::{precise_minute}::request_count"
if api_key is not None and _model is not None
@@ -286,20 +378,33 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
if api_key is not None:
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
await self.check_key_in_limits(
await self.check_key_in_limits_v2(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=max_parallel_requests,
current=cache_objects["request_count_api_key"],
request_count_api_key=request_count_api_key,
precise_minute=precise_minute,
tpm_limit=tpm_limit,
rpm_limit=rpm_limit,
current_rpm=cache_objects["rpm_api_key"] or 0,
current_tpm=cache_objects["tpm_api_key"] or 0,
current_requests=cache_objects["request_count_api_key"] or 0,
rate_limit_type="key",
values_to_update_in_cache=values_to_update_in_cache,
)
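# The v2 path increments its counters in Redis directly (see
# check_key_in_limits_v2), so the old batch cache write is no longer needed
# for the "key" limits; the legacy call is kept below, commented out.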
# await self.check_key_in_limits(
# user_api_key_dict=user_api_key_dict,
# cache=cache,
# data=data,
# call_type=call_type,
# max_parallel_requests=max_parallel_requests,
# current=cache_objects["request_count_api_key"],
# request_count_api_key=request_count_api_key,
# tpm_limit=tpm_limit,
# rpm_limit=rpm_limit,
# rate_limit_type="key",
# values_to_update_in_cache=values_to_update_in_cache,
# )
# Check if request under RPM/TPM per model for a given API Key
if (
get_key_model_tpm_limit(user_api_key_dict) is not None
@@ -441,13 +546,13 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
values_to_update_in_cache=values_to_update_in_cache,
)
asyncio.create_task(
self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # don't block execution for cache updates
)
# asyncio.create_task(
# self.internal_usage_cache.async_batch_set_cache(
# cache_list=values_to_update_in_cache,
# ttl=60,
# litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
# ) # don't block execution for cache updates
# )
return
@@ -810,6 +915,7 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
"""
Retrieve the key's remaining rate limits.
"""
return
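# The early return above appears to short-circuit this handler for now;
# everything below it is effectively dead code in this commit.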
api_key = user_api_key_dict.api_key
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")