Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

fix(parallel_request_limiter.py): fix user+team tpm/rpm limit check
Closes https://github.com/BerriAI/litellm/issues/3788
parent fa064c91fb
commit 4408b717f0

7 changed files with 157 additions and 532 deletions
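In short, the parallel request limiter stops relying on the rarely populated `user_id_rate_limits` / `team_id_rate_limits` fields and instead reads user limits from the shared user-API-key cache and team limits directly off the authenticated key. The sketch below is a minimal illustration of that resolution order, distilled from the hunks that follow rather than copied from them; the function and helper names are invented for this sketch, and it assumes a `UserAPIKeyAuth`-like object exposing `user_id`, `team_id`, `team_tpm_limit`, and `team_rpm_limit`, plus a `DualCache`-like store with `async_get_cache`.

    import sys

    def _or_no_limit(value):
        # None means "no limit" throughout the limiter.
        return sys.maxsize if value is None else value

    async def resolve_tpm_rpm_limits(user_api_key_dict, user_api_key_cache):
        user_tpm = user_rpm = team_tpm = team_rpm = sys.maxsize

        # User limits now come from the cached user record, keyed by user_id.
        if user_api_key_dict.user_id is not None:
            user_record = await user_api_key_cache.async_get_cache(
                key=user_api_key_dict.user_id
            )
            if isinstance(user_record, dict):
                user_tpm = _or_no_limit(user_record.get("tpm_limit"))
                user_rpm = _or_no_limit(user_record.get("rpm_limit"))

        # Team limits are read straight off the authenticated key object.
        if user_api_key_dict.team_id is not None:
            team_tpm = _or_no_limit(user_api_key_dict.team_tpm_limit)
            team_rpm = _or_no_limit(user_api_key_dict.team_rpm_limit)

        return user_tpm, user_rpm, team_tpm, team_rpm
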
@@ -984,10 +984,6 @@ class LiteLLM_VerificationToken(LiteLLMBase):
     org_id: Optional[str] = None  # org id for a given key
 
-    # hidden params used for parallel request limiting, not required to create a token
-    user_id_rate_limits: Optional[dict] = None
-    team_id_rate_limits: Optional[dict] = None
-
     class Config:
         protected_namespaces = ()

@@ -164,8 +164,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
-            _user_id_rate_limits = user_api_key_dict.user_id_rate_limits
+            _user_id_rate_limits = await self.user_api_key_cache.async_get_cache(
+                key=user_id
+            )
             # get user tpm/rpm limits
             if _user_id_rate_limits is not None and isinstance(
                 _user_id_rate_limits, dict

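For the lookup above to find anything, the proxy has to have cached the user record under the plain `user_id` key. A minimal sketch of that seeding, modeled on the updated `test_pre_call_hook_user_tpm_limits` test near the end of this diff (the user id and limit values are the test's illustrative numbers, not required values):

    from litellm.caching import DualCache

    user_api_key_cache = DualCache()

    # Seed the cache the way the test does: a plain dict holding the user's
    # limits, stored under the user_id that incoming requests will carry.
    user_api_key_cache.set_cache(
        key="test-user",
        value={"tpm_limit": 9, "rpm_limit": 10},
    )

    # The hook later reads this back with:
    #   await self.user_api_key_cache.async_get_cache(key=user_id)
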
@@ -196,13 +197,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         ## get team tpm/rpm limits
         team_id = user_api_key_dict.team_id
         if team_id is not None:
-            team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
-            if team_tpm_limit is None:
-                team_tpm_limit = sys.maxsize
-            team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
-            if team_rpm_limit is None:
-                team_rpm_limit = sys.maxsize
+            team_tpm_limit = user_api_key_dict.team_tpm_limit
+            team_rpm_limit = user_api_key_dict.team_rpm_limit
 
             if team_tpm_limit is None:
                 team_tpm_limit = sys.maxsize

@@ -1,379 +0,0 @@
(this hunk deletes the file in full; its former contents follow)

# What is this?
## Checks TPM/RPM Limits for a key/user/team on the proxy
## Works with Redis - if given

from typing import Optional, Literal
import litellm, traceback, sys
from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import (
    UserAPIKeyAuth,
    LiteLLM_VerificationTokenView,
    LiteLLM_UserTable,
    LiteLLM_TeamTable,
)
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm import ModelResponse
from datetime import datetime


class _PROXY_MaxTPMRPMLimiter(CustomLogger):
    user_api_key_cache = None

    # Class variables or attributes
    def __init__(self, internal_cache: Optional[DualCache]):
        if internal_cache is None:
            self.internal_cache = DualCache()
        else:
            self.internal_cache = internal_cache

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except:
            pass

    ## check if admin has set tpm/rpm limits for this key/user/team

    def _check_limits_set(
        self,
        user_api_key_cache: DualCache,
        key: Optional[str],
        user_id: Optional[str],
        team_id: Optional[str],
    ) -> bool:
        ## key
        if key is not None:
            key_val = user_api_key_cache.get_cache(key=key)
            if isinstance(key_val, dict):
                key_val = LiteLLM_VerificationTokenView(**key_val)

            if isinstance(key_val, LiteLLM_VerificationTokenView):
                user_api_key_tpm_limit = key_val.tpm_limit

                user_api_key_rpm_limit = key_val.rpm_limit

                if (
                    user_api_key_tpm_limit is not None
                    or user_api_key_rpm_limit is not None
                ):
                    return True

        ## team
        if team_id is not None:
            team_val = user_api_key_cache.get_cache(key=team_id)
            if isinstance(team_val, dict):
                team_val = LiteLLM_TeamTable(**team_val)

            if isinstance(team_val, LiteLLM_TeamTable):
                team_tpm_limit = team_val.tpm_limit

                team_rpm_limit = team_val.rpm_limit

                if team_tpm_limit is not None or team_rpm_limit is not None:
                    return True

        ## user
        if user_id is not None:
            user_val = user_api_key_cache.get_cache(key=user_id)
            if isinstance(user_val, dict):
                user_val = LiteLLM_UserTable(**user_val)

            if isinstance(user_val, LiteLLM_UserTable):
                user_tpm_limit = user_val.tpm_limit

                user_rpm_limit = user_val.rpm_limit

                if user_tpm_limit is not None or user_rpm_limit is not None:
                    return True
        return False

    async def check_key_in_limits(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        current_minute_dict: dict,
        tpm_limit: int,
        rpm_limit: int,
        request_count_api_key: str,
        type: Literal["key", "user", "team"],
    ):
        if type == "key" and user_api_key_dict.api_key is not None:
            current = current_minute_dict["key"].get(user_api_key_dict.api_key, None)
        elif type == "user" and user_api_key_dict.user_id is not None:
            current = current_minute_dict["user"].get(user_api_key_dict.user_id, None)
        elif type == "team" and user_api_key_dict.team_id is not None:
            current = current_minute_dict["team"].get(user_api_key_dict.team_id, None)
        else:
            return
        if current is None:
            if tpm_limit == 0 or rpm_limit == 0:
                # base case
                raise HTTPException(
                    status_code=429, detail="Max tpm/rpm limit reached."
                )
        elif current["current_tpm"] < tpm_limit and current["current_rpm"] < rpm_limit:
            pass
        else:
            raise HTTPException(status_code=429, detail="Max tpm/rpm limit reached.")

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        self.print_verbose(
            f"Inside Max TPM/RPM Limiter Pre-Call Hook - {user_api_key_dict}"
        )
        api_key = user_api_key_dict.api_key
        # check if REQUEST ALLOWED for user_id
        user_id = user_api_key_dict.user_id
        ## get team tpm/rpm limits
        team_id = user_api_key_dict.team_id

        self.user_api_key_cache = cache

        _set_limits = self._check_limits_set(
            user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id
        )

        self.print_verbose(f"_set_limits: {_set_limits}")

        if _set_limits == False:
            return

        # ------------
        # Setup values
        # ------------

        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
        cache_key = "usage:{}".format(precise_minute)
        current_minute_dict = await self.internal_cache.async_get_cache(
            key=cache_key
        )  # {"usage:{curr_minute}": {"key": {<api_key>: {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}}}}

        if current_minute_dict is None:
            current_minute_dict = {"key": {}, "user": {}, "team": {}}

        if api_key is not None:
            tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
            if tpm_limit is None:
                tpm_limit = sys.maxsize
            rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
            if rpm_limit is None:
                rpm_limit = sys.maxsize
            request_count_api_key = f"{api_key}::{precise_minute}::request_count"
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                current_minute_dict=current_minute_dict,
                request_count_api_key=request_count_api_key,
                tpm_limit=tpm_limit,
                rpm_limit=rpm_limit,
                type="key",
            )

        if user_id is not None:
            _user_id_rate_limits = user_api_key_dict.user_id_rate_limits

            # get user tpm/rpm limits
            if _user_id_rate_limits is not None and isinstance(
                _user_id_rate_limits, dict
            ):
                user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
                user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
                if user_tpm_limit is None:
                    user_tpm_limit = sys.maxsize
                if user_rpm_limit is None:
                    user_rpm_limit = sys.maxsize

            # now do the same tpm/rpm checks
            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                current_minute_dict=current_minute_dict,
                request_count_api_key=request_count_api_key,
                tpm_limit=user_tpm_limit,
                rpm_limit=user_rpm_limit,
                type="user",
            )

        # TEAM RATE LIMITS
        if team_id is not None:
            team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)

            if team_tpm_limit is None:
                team_tpm_limit = sys.maxsize
            team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
            if team_rpm_limit is None:
                team_rpm_limit = sys.maxsize

            if team_tpm_limit is None:
                team_tpm_limit = sys.maxsize
            if team_rpm_limit is None:
                team_rpm_limit = sys.maxsize

            # now do the same tpm/rpm checks
            request_count_api_key = f"{team_id}::{precise_minute}::request_count"

            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
            await self.check_key_in_limits(
                user_api_key_dict=user_api_key_dict,
                current_minute_dict=current_minute_dict,
                request_count_api_key=request_count_api_key,
                tpm_limit=team_tpm_limit,
                rpm_limit=team_rpm_limit,
                type="team",
            )

        return

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.print_verbose(f"INSIDE TPM RPM Limiter ASYNC SUCCESS LOGGING")

            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_user_id", None
            )
            user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_team_id", None
            )
            _limits_set = self._check_limits_set(
                user_api_key_cache=self.user_api_key_cache,
                key=user_api_key,
                user_id=user_api_key_user_id,
                team_id=user_api_key_team_id,
            )

            if _limits_set == False:  # don't waste cache calls if no tpm/rpm limits set
                return

            # ------------
            # Setup values
            # ------------

            current_date = datetime.now().strftime("%Y-%m-%d")
            current_hour = datetime.now().strftime("%H")
            current_minute = datetime.now().strftime("%M")
            precise_minute = f"{current_date}-{current_hour}-{current_minute}"

            total_tokens = 0

            if isinstance(response_obj, ModelResponse):
                total_tokens = response_obj.usage.total_tokens

            """
            - get value from redis
            - increment requests + 1
            - increment tpm + 1
            - increment rpm + 1
            - update value in-memory + redis
            """
            cache_key = "usage:{}".format(precise_minute)
            if (
                self.internal_cache.redis_cache is not None
            ):  # get straight from redis if possible
                current_minute_dict = (
                    await self.internal_cache.redis_cache.async_get_cache(
                        key=cache_key,
                    )
                )  # {"usage:{current_minute}": {"key": {}, "team": {}, "user": {}}}
            else:
                current_minute_dict = await self.internal_cache.async_get_cache(
                    key=cache_key,
                )

            if current_minute_dict is None:
                current_minute_dict = {"key": {}, "user": {}, "team": {}}

            _cache_updated = False  # check if a cache update is required. prevent unnecessary rewrites.

            # ------------
            # Update usage - API Key
            # ------------

            if user_api_key is not None:
                _cache_updated = True
                ## API KEY ##
                if user_api_key in current_minute_dict["key"]:
                    current_key_usage = current_minute_dict["key"][user_api_key]
                    new_val = {
                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
                        "current_rpm": current_key_usage["current_rpm"] + 1,
                    }
                else:
                    new_val = {
                        "current_tpm": total_tokens,
                        "current_rpm": 1,
                    }

                current_minute_dict["key"][user_api_key] = new_val

            self.print_verbose(
                f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
            )

            # ------------
            # Update usage - User
            # ------------
            if user_api_key_user_id is not None:
                _cache_updated = True
                total_tokens = 0

                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens

                if user_api_key_user_id in current_minute_dict["key"]:
                    current_key_usage = current_minute_dict["key"][user_api_key_user_id]
                    new_val = {
                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
                        "current_rpm": current_key_usage["current_rpm"] + 1,
                    }
                else:
                    new_val = {
                        "current_tpm": total_tokens,
                        "current_rpm": 1,
                    }

                current_minute_dict["user"][user_api_key_user_id] = new_val

            # ------------
            # Update usage - Team
            # ------------
            if user_api_key_team_id is not None:
                _cache_updated = True
                total_tokens = 0

                if isinstance(response_obj, ModelResponse):
                    total_tokens = response_obj.usage.total_tokens

                if user_api_key_team_id in current_minute_dict["key"]:
                    current_key_usage = current_minute_dict["key"][user_api_key_team_id]
                    new_val = {
                        "current_tpm": current_key_usage["current_tpm"] + total_tokens,
                        "current_rpm": current_key_usage["current_rpm"] + 1,
                    }
                else:
                    new_val = {
                        "current_tpm": total_tokens,
                        "current_rpm": 1,
                    }

                current_minute_dict["team"][user_api_key_team_id] = new_val

            if _cache_updated == True:
                await self.internal_cache.async_set_cache(
                    key=cache_key, value=current_minute_dict
                )

        except Exception as e:
            self.print_verbose("{}\n{}".format(e, traceback.format_exc()))  # noqa

@@ -397,6 +397,7 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
 
 def get_custom_headers(
     *,
+    user_api_key_dict: UserAPIKeyAuth,
     model_id: Optional[str] = None,
     cache_key: Optional[str] = None,
     api_base: Optional[str] = None,

@@ -410,6 +411,8 @@ def get_custom_headers(
         "x-litellm-model-api-base": api_base,
         "x-litellm-version": version,
         "x-litellm-model-region": model_region,
+        "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
+        "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
     }
     try:
         return {

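Because `get_custom_headers` now receives `user_api_key_dict`, every proxy response also reports the calling key's configured limits. A small client-side sketch of reading them back; the URL, virtual key, and model below are placeholders, and each header value is the string "None" when no limit is set on the key:

    import requests

    # Illustrative request against a locally running proxy (placeholder URL and key).
    resp = requests.post(
        "http://localhost:4000/chat/completions",
        headers={"Authorization": "Bearer sk-1234"},
        json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
    )

    # Headers added by this commit.
    print(resp.headers.get("x-litellm-key-tpm-limit"))
    print(resp.headers.get("x-litellm-key-rpm-limit"))
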
@@ -4059,6 +4062,7 @@ async def chat_completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
                 model_id=model_id,
                 cache_key=cache_key,
                 api_base=api_base,

@@ -4078,6 +4082,7 @@ async def chat_completion(
 
         fastapi_response.headers.update(
             get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
                 model_id=model_id,
                 cache_key=cache_key,
                 api_base=api_base,

@@ -4298,6 +4303,7 @@ async def completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
                 model_id=model_id,
                 cache_key=cache_key,
                 api_base=api_base,

@@ -4316,6 +4322,7 @@ async def completion(
         )
         fastapi_response.headers.update(
             get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
                 model_id=model_id,
                 cache_key=cache_key,
                 api_base=api_base,

@@ -4565,6 +4572,7 @@ async def embeddings(
 
     fastapi_response.headers.update(
         get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,

@@ -4748,6 +4756,7 @@ async def image_generation(
 
     fastapi_response.headers.update(
         get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,

@@ -4949,6 +4958,7 @@ async def audio_transcriptions(
 
     fastapi_response.headers.update(
         get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,

@@ -5132,6 +5142,7 @@ async def moderations(
 
     fastapi_response.headers.update(
         get_custom_headers(
+            user_api_key_dict=user_api_key_dict,
            model_id=model_id,
            cache_key=cache_key,
            api_base=api_base,

@@ -35,7 +35,6 @@ from litellm import (
 )
 from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
-from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy.db.base_client import CustomDB

@@ -81,9 +80,6 @@ class ProxyLogging:
         self.call_details["user_api_key_cache"] = user_api_key_cache
         self.internal_usage_cache = DualCache()
         self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
-        self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(
-            internal_cache=self.internal_usage_cache
-        )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None

@@ -1,162 +1,163 @@
+### REPLACED BY 'test_parallel_request_limiter.py' ###
(in the new revision every remaining line of this test file is commented out with a leading '# '; the original body, shown once below, is otherwise unchanged)

# What is this?
## Unit tests for the max tpm / rpm limiter hook for proxy

import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from typing import Optional

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache, RedisCache
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
from datetime import datetime


@pytest.mark.asyncio
async def test_pre_call_hook_rpm_limits():
    """
    Test if error raised on hitting rpm limits
    """
    litellm.set_verbose = True
    _api_key = hash_token("sk-12345")
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
    local_cache = DualCache()
    # redis_usage_cache = RedisCache()

    local_cache.set_cache(
        key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
    )

    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())

    await tpm_rpm_limiter.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}

    await tpm_rpm_limiter.async_log_success_event(
        kwargs=kwargs,
        response_obj="",
        start_time="",
        end_time="",
    )

    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}

    try:
        await tpm_rpm_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )

        pytest.fail(f"Expected call to fail")
    except Exception as e:
        assert e.status_code == 429


@pytest.mark.asyncio
async def test_pre_call_hook_team_rpm_limits(
    _redis_usage_cache: Optional[RedisCache] = None,
):
    """
    Test if error raised on hitting team rpm limits
    """
    litellm.set_verbose = True
    _api_key = "sk-12345"
    _team_id = "unique-team-id"
    _user_api_key_dict = {
        "api_key": _api_key,
        "max_parallel_requests": 1,
        "tpm_limit": 9,
        "rpm_limit": 10,
        "team_rpm_limit": 1,
        "team_id": _team_id,
    }
    user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
    _api_key = hash_token(_api_key)
    local_cache = DualCache()
    local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
    internal_cache = DualCache(redis_cache=_redis_usage_cache)
    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=internal_cache)
    await tpm_rpm_limiter.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    kwargs = {
        "litellm_params": {
            "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
        }
    }

    await tpm_rpm_limiter.async_log_success_event(
        kwargs=kwargs,
        response_obj="",
        start_time="",
        end_time="",
    )

    print(f"local_cache: {local_cache}")

    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}

    try:
        await tpm_rpm_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )

        pytest.fail(f"Expected call to fail")
    except Exception as e:
        assert e.status_code == 429  # type: ignore


@pytest.mark.asyncio
async def test_namespace():
    """
    - test if default namespace set via `proxyconfig._init_cache`
    - respected for tpm/rpm caching
    """
    from litellm.proxy.proxy_server import ProxyConfig

    redis_usage_cache: Optional[RedisCache] = None
    cache_params = {"type": "redis", "namespace": "litellm_default"}

    ## INIT CACHE ##
    proxy_config = ProxyConfig()
    setattr(litellm.proxy.proxy_server, "proxy_config", proxy_config)

    proxy_config._init_cache(cache_params=cache_params)

    redis_cache: Optional[RedisCache] = getattr(
        litellm.proxy.proxy_server, "redis_usage_cache"
    )

    ## CHECK IF NAMESPACE SET ##
    assert redis_cache.namespace == "litellm_default"

    ## CHECK IF TPM/RPM RATE LIMITING WORKS ##
    await test_pre_call_hook_team_rpm_limits(_redis_usage_cache=redis_cache)
    current_date = datetime.now().strftime("%Y-%m-%d")
    current_hour = datetime.now().strftime("%H")
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"

    cache_key = "litellm_default:usage:{}".format(precise_minute)
    value = await redis_cache.async_get_cache(key=cache_key)
    assert value is not None

@@ -229,17 +229,21 @@ async def test_pre_call_hook_user_tpm_limits():
     """
     Test if error raised on hitting tpm limits
     """
+    local_cache = DualCache()
     # create user with tpm/rpm limits
+    user_id = "test-user"
+    user_obj = {"tpm_limit": 9, "rpm_limit": 10}
+
+    local_cache.set_cache(key=user_id, value=user_obj)
+
     _api_key = "sk-12345"
     user_api_key_dict = UserAPIKeyAuth(
         api_key=_api_key,
-        user_id="ishaan",
-        user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
+        user_id=user_id,
     )
     res = dict(user_api_key_dict)
     print("dict user", res)
-    local_cache = DualCache()
     parallel_request_handler = MaxParallelRequestsHandler()
 
     await parallel_request_handler.async_pre_call_hook(

@@ -248,7 +252,7 @@ async def test_pre_call_hook_user_tpm_limits():
 
     kwargs = {
         "litellm_params": {
-            "metadata": {"user_api_key_user_id": "ishaan", "user_api_key": "gm"}
+            "metadata": {"user_api_key_user_id": user_id, "user_api_key": "gm"}
         }
     }
 

@@ -734,7 +738,7 @@ async def test_bad_router_call():
     request_count_api_key = f"{_api_key}::{precise_minute}::request_count"
 
     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.user_api_key_cache.get_cache(  # type: ignore
            key=request_count_api_key
        )["current_requests"]
        == 1

@@ -751,7 +755,7 @@ async def test_bad_router_call():
     except:
         pass
     assert (
-        parallel_request_handler.user_api_key_cache.get_cache(
+        parallel_request_handler.user_api_key_cache.get_cache(  # type: ignore
            key=request_count_api_key
        )["current_requests"]
        == 0