mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

fix(parallel_request_limiter.py): use redis cache, if available, for rate limiting across instances

Fixes https://github.com/BerriAI/litellm/issues/4148

parent c059352908, commit 76c9b715f2

4 changed files with 75 additions and 56 deletions
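The diff below routes every rate-limit counter through a shared `DualCache`: a fast in-memory cache that reads through to Redis when a Redis client is attached, so counters written by one proxy instance become visible to the others. A minimal standalone sketch of that pattern, with method names mirroring the calls in the diff (this is an illustration, not litellm's actual implementation):

```python
import time
from typing import Any, Optional


class InMemoryCache:
    """Process-local cache with per-entry TTLs (sketch)."""

    def __init__(self, default_ttl: Optional[float] = None):
        self.store: dict = {}
        self.default_ttl = default_ttl

    def get(self, key: str) -> Optional[Any]:
        entry = self.store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if expires_at is not None and time.time() > expires_at:
            del self.store[key]  # expired: force a read from the shared store
            return None
        return value

    def set(self, key: str, value: Any, ttl: Optional[float] = None) -> None:
        ttl = ttl if ttl is not None else self.default_ttl
        expires_at = time.time() + ttl if ttl is not None else None
        self.store[key] = (value, expires_at)


class DualCacheSketch:
    """In-memory cache in front of a shared (e.g. Redis) cache.

    Reads try memory first and fall through to the shared store;
    writes go to both unless local_only=True.
    """

    def __init__(self, redis_cache=None, default_in_memory_ttl: float = 1.0):
        self.in_memory = InMemoryCache(default_ttl=default_in_memory_ttl)
        self.redis_cache = redis_cache  # may be attached later, as update_values() does below

    async def async_get_cache(self, key: str, local_only: bool = False) -> Optional[Any]:
        value = self.in_memory.get(key)
        if value is None and not local_only and self.redis_cache is not None:
            value = await self.redis_cache.async_get_cache(key=key)
            # a fuller implementation would also repopulate the local cache here
        return value

    async def async_set_cache(
        self, key: str, value: Any, ttl: Optional[float] = None, local_only: bool = False
    ) -> None:
        self.in_memory.set(key, value, ttl=ttl)
        if not local_only and self.redis_cache is not None:
            await self.redis_cache.async_set_cache(key, value, ttl=ttl)
```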
litellm/proxy/proxy_config.yaml (example proxy config; path inferred from the settings shown):

```diff
@@ -68,6 +68,7 @@ router_settings:
 litellm_settings:
   success_callback: ["langfuse"]
   failure_callback: ["langfuse"]
+  cache: true
 
 # general_settings:
 #   alerting: ["email"]
```
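Setting `cache: true` is what gives the proxy a Redis client to hand to the usage cache at startup. The exact bootstrapping lives elsewhere in the proxy; a hedged sketch of the idea, assuming the `REDIS_HOST`/`REDIS_PORT`/`REDIS_PASSWORD` environment variables that litellm's docs describe:

```python
# Sketch only: how a Redis handle for the shared usage cache can be derived
# from the environment when `cache: true` is set. The env var names follow
# litellm's documented convention; the construction itself is illustrative.
import os

redis_params = {
    "host": os.getenv("REDIS_HOST"),
    "port": os.getenv("REDIS_PORT"),
    "password": os.getenv("REDIS_PASSWORD"),
}

if all(redis_params.values()):
    print("Redis configured: rate-limit counters can be shared across instances")
else:
    print("No Redis: the limiter falls back to per-instance in-memory counters")
```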
litellm/proxy/hooks/parallel_request_limiter.py:

```diff
@@ -10,18 +10,17 @@ from datetime import datetime
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
-    user_api_key_cache = None
 
     # Class variables or attributes
-    def __init__(self):
-        pass
+    def __init__(self, internal_usage_cache: DualCache):
+        self.internal_usage_cache = internal_usage_cache
 
     def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
-        except:
+        except Exception:
            pass
 
     async def check_key_in_limits(
```
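The handler previously kept a `user_api_key_cache` reference that was re-saved on every request; it now receives one shared `DualCache` at construction time. Wiring it up looks roughly like this (`DualCache` is litellm's public class; the handler is a proxy-internal module, and the import path is stated here as an assumption from the repo layout):

```python
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
)

# mirrors what ProxyLogging.__init__ does further down in this commit
internal_usage_cache = DualCache(default_in_memory_ttl=1)
limiter = _PROXY_MaxParallelRequestsHandler(internal_usage_cache)
```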
```diff
@@ -35,7 +34,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         rpm_limit: int,
         request_count_api_key: str,
     ):
-        current = cache.get_cache(
+        current = await self.internal_usage_cache.async_get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         if current is None:
@@ -49,7 +48,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": 0,
                 "current_rpm": 0,
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         elif (
             int(current["current_requests"]) < max_parallel_requests
             and current["current_tpm"] < tpm_limit
@@ -61,7 +62,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": current["current_tpm"],
                 "current_rpm": current["current_rpm"],
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         else:
             raise HTTPException(
                 status_code=429,
```
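The hunks above are a plain read-modify-write: fetch the per-key counter dict, admit and increment if every limit still has headroom, otherwise raise a 429. A standalone sketch of that decision, with the counter layout taken from the diff (the function itself is illustrative):

```python
import sys
from typing import Optional

from fastapi import HTTPException


def admit(
    current: Optional[dict],
    max_parallel_requests: int,
    tpm_limit: int = sys.maxsize,
    rpm_limit: int = sys.maxsize,
) -> dict:
    """Return the updated counter dict, or raise 429 if a limit is hit."""
    if current is None:
        # first request in this window
        return {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}
    if (
        int(current["current_requests"]) < max_parallel_requests
        and current["current_tpm"] < tpm_limit
        and current["current_rpm"] < rpm_limit
    ):
        return {
            "current_requests": current["current_requests"] + 1,
            "current_tpm": current["current_tpm"],
            "current_rpm": current["current_rpm"],
        }
    raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
```

Note that the get-then-set sequence is not atomic: Redis makes the counters shared across instances, not race-free, so two instances reading simultaneously can still each admit a request against the same headroom.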
```diff
@@ -75,7 +78,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         data: dict,
         call_type: str,
     ):
-        self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
+        self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests
         if max_parallel_requests is None:
@@ -90,7 +93,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if rpm_limit is None:
             rpm_limit = sys.maxsize
 
-        self.user_api_key_cache = cache  # save the api key cache for updating the value
         # ------------
         # Setup values
         # ------------
@@ -98,7 +100,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if global_max_parallel_requests is not None:
             # get value from cache
             _key = "global_max_parallel_requests"
-            current_global_requests = await cache.async_get_cache(
+            current_global_requests = await self.internal_usage_cache.async_get_cache(
                 key=_key, local_only=True
             )
             # check if below limit
@@ -111,7 +113,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 )
             # if below -> increment
             else:
-                await cache.async_increment_cache(key=_key, value=1, local_only=True)
+                await self.internal_usage_cache.async_increment_cache(
+                    key=_key, value=1, local_only=True
+                )
 
         current_date = datetime.now().strftime("%Y-%m-%d")
         current_hour = datetime.now().strftime("%H")
```
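Both global-counter hunks pass `local_only=True`, so `global_max_parallel_requests` deliberately stays in the instance's in-memory cache even when Redis is attached; only the per-key, per-user, and per-team counters are shared. A sketch of the call shape, written against any cache object exposing the same method:

```python
# Illustrative wrapper around the increment used for the global counter.
async def bump_global_parallel_requests(cache, delta: int) -> None:
    # local_only=True keeps this counter per-instance even with Redis attached
    await cache.async_increment_cache(
        key="global_max_parallel_requests", value=delta, local_only=True
    )
```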
```diff
@@ -123,7 +127,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 
         # CHECK IF REQUEST ALLOWED for key
 
-        current = cache.get_cache(
+        current = await self.internal_usage_cache.async_get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         self.print_verbose(f"current: {current}")
@@ -143,7 +147,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": 0,
                 "current_rpm": 0,
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         elif (
             int(current["current_requests"]) < max_parallel_requests
             and current["current_tpm"] < tpm_limit
@@ -155,7 +161,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": current["current_tpm"],
                 "current_rpm": current["current_rpm"],
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         else:
             raise HTTPException(
                 status_code=429, detail="Max parallel request limit reached."
```
```diff
@@ -164,7 +172,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
-            _user_id_rate_limits = await self.user_api_key_cache.async_get_cache(
+            _user_id_rate_limits = await self.internal_usage_cache.async_get_cache(
                 key=user_id
             )
             # get user tpm/rpm limits
@@ -256,7 +264,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
+            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
                 "global_max_parallel_requests", None
             )
```
```diff
@@ -269,9 +277,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             )
             user_api_key_end_user_id = kwargs.get("user")
 
-            if self.user_api_key_cache is None:
-                return
-
             # ------------
             # Setup values
             # ------------
@@ -280,7 +285,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )
 
@@ -303,7 +308,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -320,7 +325,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
 
```
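Each counter key embeds the current minute, and the success hook writes the updated value back with `ttl=60`, so old windows expire from the cache on their own. The key format is taken verbatim from the diff; assembling `precise_minute` from date, hour, and minute is shown here as an assumption based on the `current_date`/`current_hour` lines visible above:

```python
from datetime import datetime


def request_count_key(entity_id: str) -> str:
    """Per-minute counter key for TPM/RPM accounting (sketch)."""
    now = datetime.now()
    precise_minute = f"{now:%Y-%m-%d}-{now:%H}-{now:%M}"
    return f"{entity_id}::{precise_minute}::request_count"


# e.g. "hashed-api-key::2024-06-13-14-07::request_count"
print(request_count_key("hashed-api-key"))
```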
```diff
@@ -337,7 +342,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_user_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -354,7 +359,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
 
@@ -371,7 +376,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_team_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -388,7 +393,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
 
@@ -405,7 +410,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_end_user_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -422,7 +427,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
 
```
```diff
@@ -442,9 +447,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             if user_api_key is None:
                 return
 
-            if self.user_api_key_cache is None:
-                return
-
             ## decrement call count if call failed
             if "Max parallel request limit reached" in str(kwargs["exception"]):
                 pass  # ignore failed calls due to max limit being reached
```
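The failure hook only unwinds counters when the exception was not the limiter's own 429: a request it rejected never incremented anything, so decrementing would double-count. A sketch of that guard:

```python
def should_unwind_counters(exception: Exception) -> bool:
    """Illustrative version of the check in the failure hook above."""
    # Requests rejected with "Max parallel request limit reached" never
    # incremented the counters, so they must not decrement them either.
    return "Max parallel request limit reached" not in str(exception)
```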
```diff
@@ -457,12 +459,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 current_global_requests = (
-                    await self.user_api_key_cache.async_get_cache(
+                    await self.internal_usage_cache.async_get_cache(
                         key=_key, local_only=True
                     )
                 )
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )
 
@@ -478,7 +480,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             # ------------
             # Update usage
             # ------------
-            current = self.user_api_key_cache.get_cache(
+            current = await self.internal_usage_cache.async_get_cache(
                 key=request_count_api_key
             ) or {
                 "current_requests": 1,
@@ -493,7 +495,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             }
 
             self.print_verbose(f"updated_value in failure call: {new_val}")
-            self.user_api_key_cache.set_cache(
+            await self.internal_usage_cache.async_set_cache(
                 request_count_api_key, new_val, ttl=60
             )  # save in cache for up to 1 min.
         except Exception as e:
```
litellm/proxy/proxy_server.py:

```diff
@@ -2852,6 +2852,7 @@ class ProxyConfig:
             use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
             load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
             ### ALERTING ###
+
             proxy_logging_obj.update_values(
                 alerting=general_settings.get("alerting", None),
                 alerting_threshold=general_settings.get("alerting_threshold", 600),
@@ -3963,6 +3964,11 @@ async def startup_event():
 
     db_writer_client = HTTPHandler()
 
+    ## UPDATE INTERNAL USAGE CACHE ##
+    proxy_logging_obj.update_values(
+        redis_cache=redis_usage_cache
+    )  # used by parallel request limiter for rate limiting keys across instances
+
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
     if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
```
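The startup hook attaches the Redis handle to the already-constructed cache instead of rebuilding anything: `ProxyLogging` creates the `DualCache` at process start, and `update_values(redis_cache=...)` bolts Redis on once configuration is known. The attach-later pattern in isolation (a sketch; `redis_usage_cache` stands in for whatever Redis client the proxy built):

```python
def attach_redis(cache, redis_usage_cache) -> None:
    """Mutate an existing dual cache so its reads/writes go through Redis."""
    if redis_usage_cache is not None:
        cache.redis_cache = redis_usage_cache  # counters become cross-instance
```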
litellm/proxy/utils.py:

```diff
@@ -162,8 +162,12 @@ class ProxyLogging:
         ## INITIALIZE LITELLM CALLBACKS ##
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
-        self.internal_usage_cache = DualCache()
-        self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
+        self.internal_usage_cache = DualCache(
+            default_in_memory_ttl=1
+        )  # ping redis cache every 1s
+        self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler(
+            self.internal_usage_cache
+        )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
```
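`default_in_memory_ttl=1` caps how stale a locally cached counter can be: after one second the in-memory entry expires and the next read falls through to Redis, which is what the "ping redis cache every 1s" comment means. Instances therefore converge on each other's writes within about a second. A small demo against the `DualCacheSketch` from the first sketch, assuming a Redis backend is attached:

```python
import asyncio


async def demo(cache) -> None:
    # cache: a dual cache with a 1s in-memory TTL and Redis attached
    await cache.async_set_cache("counter", {"current_requests": 1})
    print(await cache.async_get_cache("counter"))  # served from local memory
    await asyncio.sleep(1.1)  # local entry expires
    print(await cache.async_get_cache("counter"))  # re-read from Redis
```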
```diff
@@ -189,39 +193,45 @@ class ProxyLogging:
 
     def update_values(
         self,
-        alerting: Optional[List],
-        alerting_threshold: Optional[float],
-        redis_cache: Optional[RedisCache],
+        alerting: Optional[List] = None,
+        alerting_threshold: Optional[float] = None,
+        redis_cache: Optional[RedisCache] = None,
         alert_types: Optional[List[AlertType]] = None,
         alerting_args: Optional[dict] = None,
     ):
-        self.alerting = alerting
+        updated_slack_alerting: bool = False
+        if self.alerting is not None:
+            self.alerting = alerting
+            updated_slack_alerting = True
         if alerting_threshold is not None:
             self.alerting_threshold = alerting_threshold
+            updated_slack_alerting = True
         if alert_types is not None:
             self.alert_types = alert_types
+            updated_slack_alerting = True
 
-        self.slack_alerting_instance.update_values(
-            alerting=self.alerting,
-            alerting_threshold=self.alerting_threshold,
-            alert_types=self.alert_types,
-            alerting_args=alerting_args,
-        )
+        if updated_slack_alerting is True:
+            self.slack_alerting_instance.update_values(
+                alerting=self.alerting,
+                alerting_threshold=self.alerting_threshold,
+                alert_types=self.alert_types,
+                alerting_args=alerting_args,
+            )
 
-        if (
-            self.alerting is not None
-            and "slack" in self.alerting
-            and "daily_reports" in self.alert_types
-        ):
-            # NOTE: ENSURE we only add callbacks when alerting is on
-            # We should NOT add callbacks when alerting is off
-            litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+            if (
+                self.alerting is not None
+                and "slack" in self.alerting
+                and "daily_reports" in self.alert_types
+            ):
+                # NOTE: ENSURE we only add callbacks when alerting is on
+                # We should NOT add callbacks when alerting is off
+                litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+
+        if redis_cache is not None:
+            self.internal_usage_cache.redis_cache = redis_cache
 
     def _init_litellm_callbacks(self):
-        print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
+        print_verbose("INITIALIZING LITELLM CALLBACKS!")
         self.service_logging_obj = ServiceLogging()
         litellm.callbacks.append(self.max_parallel_request_limiter)
         litellm.callbacks.append(self.max_budget_limiter)
```
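Because every parameter of `update_values` now defaults to `None`, the startup path can pass only the Redis handle without touching alerting state; each branch mutates state only for arguments actually supplied. The same contract in miniature (names are illustrative):

```python
from typing import Any, List, Optional


class SettingsSketch:
    def __init__(self) -> None:
        self.alerting: Optional[List[str]] = None
        self.redis_cache: Optional[Any] = None

    def update_values(
        self,
        alerting: Optional[List[str]] = None,
        redis_cache: Optional[Any] = None,
    ) -> None:
        if alerting is not None:
            self.alerting = alerting
        if redis_cache is not None:
            self.redis_cache = redis_cache  # attach the shared store lazily


settings = SettingsSketch()
settings.update_values(redis_cache=object())  # alerting left untouched
```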