Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
refactor: move to new check key in limits logic

Uses Redis increment-based cache logic so that TPM/RPM accounting stays correct across proxy instances.
Parent: 5266c8adf7
Commit: 937a6e63ed
2 changed files with 131 additions and 16 deletions
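Why the switch matters: with several proxy instances sharing one cache, a read-modify-write counter (fetch the per-minute value, bump it in Python, write it back) can lose updates when two instances race. Redis INCR performs the bump server-side and atomically. A minimal sketch of that fixed-window pattern, reusing the `{key}::{precise_minute}::{group}` layout from the diff below; the helper, the key value, and the minute format are illustrative assumptions, not LiteLLM's actual API:

# Minimal sketch of a cross-instance, fixed-window counter (illustrative;
# not LiteLLM's implementation). Requires `pip install redis`.
import asyncio
from datetime import datetime

import redis.asyncio as redis


async def increment_window_counter(r: redis.Redis, key: str, ttl: int = 60) -> int:
    pipe = r.pipeline()
    pipe.incr(key)         # atomic server-side increment, safe across instances
    pipe.expire(key, ttl)  # the minute window cleans itself up, no sweeper needed
    count, _ = await pipe.execute()
    return count


async def main() -> None:
    r = redis.Redis()
    precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")  # assumed format
    print(await increment_window_counter(r, f"sk-1234::{precise_minute}::rpm"))


asyncio.run(main())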
@@ -1,7 +1,17 @@
 import asyncio
 import sys
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
+    cast,
+)

 from fastapi import HTTPException
 from pydantic import BaseModel
@@ -32,11 +42,16 @@ else:

 class CacheObject(TypedDict):
     current_global_requests: Optional[dict]
-    request_count_api_key: Optional[dict]
+    request_count_api_key: Optional[int]
     request_count_api_key_model: Optional[dict]
     request_count_user_id: Optional[dict]
     request_count_team_id: Optional[dict]
     request_count_end_user_id: Optional[dict]
+    rpm_api_key: Optional[int]
+    tpm_api_key: Optional[int]
+
+
+RateLimitGroups = Literal["request_count", "tpm", "rpm"]


 class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
@@ -58,6 +73,69 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
         except Exception:
             pass

+    def _get_current_usage_key(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        precise_minute: str,
+        data: dict,
+        rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
+        group: RateLimitGroups,
+    ) -> str:
+        if rate_limit_type == "key":
+            return f"{user_api_key_dict.api_key}::{precise_minute}::{group}"
+        elif rate_limit_type == "model_per_key":
+            return f"{user_api_key_dict.api_key}::{data.get('model')}::{precise_minute}::{group}"
+        elif rate_limit_type == "user":
+            return f"{user_api_key_dict.user_id}::{precise_minute}::{group}"
+        elif rate_limit_type == "customer":
+            return f"{user_api_key_dict.end_user_id}::{precise_minute}::{group}"
+        elif rate_limit_type == "team":
+            return f"{user_api_key_dict.team_id}::{precise_minute}::{group}"
+        else:
+            raise ValueError(f"Invalid rate limit type: {rate_limit_type}")
+
+    async def check_key_in_limits_v2(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        max_parallel_requests: int,
+        precise_minute: str,
+        tpm_limit: int,
+        rpm_limit: int,
+        current_rpm: int,
+        current_tpm: int,
+        current_requests: int,
+        rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
+    ):
+        verbose_proxy_logger.info(
+            f"Current Usage of {rate_limit_type} in this minute: {current_requests}, {current_tpm}, {current_rpm}"
+        )
+        if (
+            current_requests >= max_parallel_requests
+            or current_tpm >= tpm_limit
+            or current_rpm >= rpm_limit
+        ):
+            raise self.raise_rate_limit_error(
+                additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
+            )

+        ## INCREMENT CURRENT USAGE
+        increment_list: List[Tuple[str, int]] = []
+        for group in ["request_count", "rpm"]:
+            key = self._get_current_usage_key(
+                user_api_key_dict=user_api_key_dict,
+                precise_minute=precise_minute,
+                data=data,
+                rate_limit_type=rate_limit_type,
+                group=cast(RateLimitGroups, group),
+            )
+            increment_list.append((key, 1))
+
+        await self._increment_value_list_in_current_window(
+            increment_list=increment_list,
+            ttl=60,
+        )
+
     async def check_key_in_limits(
         self,
         user_api_key_dict: UserAPIKeyAuth,
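For concreteness: the key scheme above flattens every scope to `<entity>::<minute>::<group>`, with the model name spliced in for `model_per_key`. A toy, dependency-free rendering of what `_get_current_usage_key` produces; the `precise_minute` value shown is an assumed format:

# Toy reimplementation of the key scheme above, for illustration only.
from typing import Literal

RateLimitGroups = Literal["request_count", "tpm", "rpm"]


def usage_key(entity_id: str, precise_minute: str, group: RateLimitGroups) -> str:
    # Every rate-limit type reduces to "<entity>::<minute>::<group>";
    # model_per_key additionally embeds the model name after the entity.
    return f"{entity_id}::{precise_minute}::{group}"


# e.g. for an API key during one minute of 2025-04-24:
print(usage_key("sk-1234", "2025-04-24-18-24", "rpm"))
# -> sk-1234::2025-04-24-18-24::rpm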
@@ -150,6 +228,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
         request_count_user_id: Optional[str],
         request_count_team_id: Optional[str],
         request_count_end_user_id: Optional[str],
+        rpm_api_key: Optional[str],
+        tpm_api_key: Optional[str],
         parent_otel_span: Optional[Span] = None,
     ) -> CacheObject:
         keys = [
@@ -159,6 +239,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
             request_count_user_id,
             request_count_team_id,
             request_count_end_user_id,
+            rpm_api_key,
+            tpm_api_key,
         ]
         results = await self.internal_usage_cache.async_batch_get_cache(
             keys=keys,
@@ -173,6 +255,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
                 request_count_user_id=None,
                 request_count_team_id=None,
                 request_count_end_user_id=None,
+                rpm_api_key=None,
+                tpm_api_key=None,
             )

         return CacheObject(
@@ -182,6 +266,8 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
             request_count_user_id=results[3],
             request_count_team_id=results[4],
             request_count_end_user_id=results[5],
+            rpm_api_key=results[6],
+            tpm_api_key=results[7],
         )

    async def async_pre_call_hook(  # noqa: PLR0915
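Note that the two hunks above are positionally coupled: the order of the `keys` list must match the `results[i]` indices, so adding `rpm_api_key`/`tpm_api_key` means touching both places in lockstep. A small sketch of a name-keyed alternative that keeps the ordering in one place (illustrative, not what the commit does):

# Illustrative: zip field names with the batch-get results so the ordering
# is defined exactly once. `results` is a stand-in for the awaited values.
names = [
    "current_global_requests",
    "request_count_api_key",
    "request_count_api_key_model",
    "request_count_user_id",
    "request_count_team_id",
    "request_count_end_user_id",
    "rpm_api_key",
    "tpm_api_key",
]
results = [None] * len(names)  # stand-in for async_batch_get_cache(...) output
cache_object = dict(zip(names, results))  # no bare indices to drift out of sync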
@@ -261,6 +347,12 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
                 if api_key is not None
                 else None
             ),
+            rpm_api_key=(
+                f"{api_key}::{precise_minute}::rpm" if api_key is not None else None
+            ),
+            tpm_api_key=(
+                f"{api_key}::{precise_minute}::tpm" if api_key is not None else None
+            ),
             request_count_api_key_model=(
                 f"{api_key}::{_model}::{precise_minute}::request_count"
                 if api_key is not None and _model is not None
@@ -286,20 +378,33 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
         if api_key is not None:
             request_count_api_key = f"{api_key}::{precise_minute}::request_count"
             # CHECK IF REQUEST ALLOWED for key
-            await self.check_key_in_limits(
+            await self.check_key_in_limits_v2(
                 user_api_key_dict=user_api_key_dict,
-                cache=cache,
                 data=data,
-                call_type=call_type,
                 max_parallel_requests=max_parallel_requests,
-                current=cache_objects["request_count_api_key"],
-                request_count_api_key=request_count_api_key,
+                precise_minute=precise_minute,
                 tpm_limit=tpm_limit,
                 rpm_limit=rpm_limit,
+                current_rpm=cache_objects["rpm_api_key"] or 0,
+                current_tpm=cache_objects["tpm_api_key"] or 0,
+                current_requests=cache_objects["request_count_api_key"] or 0,
                 rate_limit_type="key",
-                values_to_update_in_cache=values_to_update_in_cache,
             )
+
+            # await self.check_key_in_limits(
+            #     user_api_key_dict=user_api_key_dict,
+            #     cache=cache,
+            #     data=data,
+            #     call_type=call_type,
+            #     max_parallel_requests=max_parallel_requests,
+            #     current=cache_objects["request_count_api_key"],
+            #     request_count_api_key=request_count_api_key,
+            #     tpm_limit=tpm_limit,
+            #     rpm_limit=rpm_limit,
+            #     rate_limit_type="key",
+            #     values_to_update_in_cache=values_to_update_in_cache,
+            # )

         # Check if request under RPM/TPM per model for a given API Key
         if (
             get_key_model_tpm_limit(user_api_key_dict) is not None
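Note the shape of the v2 flow at this call site: counters are batch-read, checked against the limits, and only then incremented. The check and the increment are separate round-trips, so two instances can both pass the check right at the boundary; enforcement is approximate there. An increment-first variant closes that window, because INCR returns the post-increment value atomically. A sketch of that alternative for comparison only (this is not what the commit implements, and it counts rejected requests against the window):

# Sketch: increment-first enforcement. Illustrative only.
import redis.asyncio as redis


async def allow_request(r: redis.Redis, key: str, limit: int, ttl: int = 60) -> bool:
    pipe = r.pipeline()
    pipe.incr(key)
    pipe.expire(key, ttl)
    count, _ = await pipe.execute()
    # INCR returns the post-increment value, so check-and-count is a single
    # atomic step: no other instance can observe an intermediate state.
    return count <= limit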
@@ -441,13 +546,13 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
                 values_to_update_in_cache=values_to_update_in_cache,
             )

-        asyncio.create_task(
-            self.internal_usage_cache.async_batch_set_cache(
-                cache_list=values_to_update_in_cache,
-                ttl=60,
-                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
-            )  # don't block execution for cache updates
-        )
+        # asyncio.create_task(
+        #     self.internal_usage_cache.async_batch_set_cache(
+        #         cache_list=values_to_update_in_cache,
+        #         ttl=60,
+        #         litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+        #     )  # don't block execution for cache updates
+        # )

         return

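The block being commented out used `asyncio.create_task(...)` to flush cache updates off the request path. If you reuse that fire-and-forget pattern elsewhere, remember that the event loop holds only a weak reference to tasks, so an unreferenced task can be garbage-collected before it finishes. A small sketch of the safer form (the coroutine argument is a hypothetical stand-in for the batch cache write):

# Fire-and-forget with a strong reference, per the asyncio docs' guidance.
import asyncio
from typing import Any, Coroutine

_background_tasks: set[asyncio.Task] = set()


def fire_and_forget(coro: Coroutine[Any, Any, Any]) -> None:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)  # the loop alone only holds a weak reference
    task.add_done_callback(_background_tasks.discard)  # drop it once done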
@@ -810,6 +915,7 @@ class _PROXY_MaxParallelRequestsHandler(BaseRoutingStrategy, CustomLogger):
         """
         Retrieve the key's remaining rate limits.
         """
+        return
         api_key = user_api_key_dict.api_key
         current_date = datetime.now().strftime("%Y-%m-%d")
         current_hour = datetime.now().strftime("%H")
@@ -5,7 +5,7 @@ Base class across routing strategies to abstract commmon functions like batch in
 import asyncio
 import threading
 from abc import ABC
-from typing import List, Optional, Set, Union
+from typing import List, Optional, Set, Tuple, Union

 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache
@@ -53,6 +53,15 @@ class BaseRoutingStrategy(ABC):
         except asyncio.CancelledError:
             pass

+    async def _increment_value_list_in_current_window(
+        self, increment_list: List[Tuple[str, int]], ttl: int
+    ):
+        """
+        Increment a list of values in the current window
+        """
+        for key, value in increment_list:
+            await self._increment_value_in_current_window(key=key, value=value, ttl=ttl)
+
     async def _increment_value_in_current_window(
         self, key: str, value: Union[int, float], ttl: int
    ):
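The new helper simply fans out over the `(key, delta)` pairs, awaiting one increment per key. A usage sketch against an in-memory stand-in (the stub class and key strings are illustrative, not the real cache layer); if the per-key round-trips ever dominate, the loop is a natural candidate for a single pipelined batch:

# Usage sketch: how the new helper fans out over (key, delta) pairs.
import asyncio
from typing import Dict, List, Tuple, Union


class FakeStrategy:
    def __init__(self) -> None:
        self.counters: Dict[str, int] = {}

    async def _increment_value_in_current_window(
        self, key: str, value: Union[int, float], ttl: int
    ):
        # The real implementation increments the cache; TTL ignored in this stub.
        self.counters[key] = self.counters.get(key, 0) + value

    async def _increment_value_list_in_current_window(
        self, increment_list: List[Tuple[str, int]], ttl: int
    ):
        for key, value in increment_list:
            await self._increment_value_in_current_window(key=key, value=value, ttl=ttl)


async def main() -> None:
    s = FakeStrategy()
    await s._increment_value_list_in_current_window(
        [("sk-1234::2025-04-24-18-24::request_count", 1),
         ("sk-1234::2025-04-24-18-24::rpm", 1)],
        ttl=60,
    )
    print(s.counters)  # both per-minute counters bumped by 1


asyncio.run(main())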