fix(parallel_request_limiter.py): use redis cache, if available, for rate limiting across instances

Fixes https://github.com/BerriAI/litellm/issues/4148
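
The gist of the change: the limiter's counters now go through the proxy's shared internal usage cache (litellm's DualCache) instead of a per-instance cache, so when a Redis cache is configured, every proxy instance reads and writes the same rate-limit state. A minimal sketch of that wiring, assuming litellm's DualCache/RedisCache constructors (the connection parameters here are hypothetical):

from litellm.caching import DualCache, RedisCache

# With redis_cache attached, counters are shared across proxy instances;
# without it, DualCache falls back to its per-instance in-memory layer.
redis_cache = RedisCache(host="localhost", port=6379)  # hypothetical connection params
internal_usage_cache = DualCache(redis_cache=redis_cache)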
Krrish Dholakia 2024-06-12 10:35:48 -07:00
parent 408a18d433
commit 77328e4a28
4 changed files with 75 additions and 56 deletions


@@ -10,18 +10,17 @@ from datetime import datetime
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
-    user_api_key_cache = None
 
     # Class variables or attributes
-    def __init__(self):
-        pass
+    def __init__(self, internal_usage_cache: DualCache):
+        self.internal_usage_cache = internal_usage_cache
 
     def print_verbose(self, print_statement):
         try:
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement)  # noqa
-        except:
+        except Exception:
             pass
 
     async def check_key_in_limits(
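
With the new signature, the handler receives the shared cache at construction time rather than discovering one per call. A hedged usage sketch:

from litellm.caching import DualCache

# In-memory only here; pass a Redis-backed DualCache (see above) to
# share limits across instances.
limiter = _PROXY_MaxParallelRequestsHandler(internal_usage_cache=DualCache())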
@@ -35,7 +34,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         rpm_limit: int,
         request_count_api_key: str,
     ):
-        current = cache.get_cache(
+        current = await self.internal_usage_cache.async_get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         if current is None:
@@ -49,7 +48,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": 0,
                 "current_rpm": 0,
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         elif (
             int(current["current_requests"]) < max_parallel_requests
             and current["current_tpm"] < tpm_limit
@@ -61,7 +62,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": current["current_tpm"],
                 "current_rpm": current["current_rpm"],
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         else:
             raise HTTPException(
                 status_code=429,
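
The two hunks above convert check_key_in_limits to the async cache API. Condensed, the read-modify-write it performs looks like this (an illustrative sketch; the limit variables and cache are assumed to be in scope, and the dict shape comes from the inline comment above):

from fastapi import HTTPException

current = await internal_usage_cache.async_get_cache(key=request_count_api_key)
if current is None:
    current = {"current_requests": 0, "current_tpm": 0, "current_rpm": 0}
if (
    current["current_requests"] < max_parallel_requests
    and current["current_tpm"] < tpm_limit
    and current["current_rpm"] < rpm_limit
):
    current["current_requests"] += 1  # claim a parallel-request slot
    await internal_usage_cache.async_set_cache(request_count_api_key, current)
else:
    raise HTTPException(status_code=429, detail="Max parallel request limit reached.")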
@@ -75,7 +78,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         data: dict,
         call_type: str,
     ):
-        self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
+        self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests
         if max_parallel_requests is None:
@@ -90,7 +93,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if rpm_limit is None:
             rpm_limit = sys.maxsize
 
-        self.user_api_key_cache = cache  # save the api key cache for updating the value
         # ------------
         # Setup values
         # ------------
@@ -98,7 +100,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if global_max_parallel_requests is not None:
             # get value from cache
             _key = "global_max_parallel_requests"
-            current_global_requests = await cache.async_get_cache(
+            current_global_requests = await self.internal_usage_cache.async_get_cache(
                 key=_key, local_only=True
             )
             # check if below limit
@@ -111,7 +113,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 )
             # if below -> increment
             else:
-                await cache.async_increment_cache(key=_key, value=1, local_only=True)
+                await self.internal_usage_cache.async_increment_cache(
+                    key=_key, value=1, local_only=True
+                )
 
         current_date = datetime.now().strftime("%Y-%m-%d")
         current_hour = datetime.now().strftime("%H")
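
Note that the global parallel-request gauge keeps local_only=True, so it stays in the in-memory layer and is tracked per instance even when Redis is configured. The pattern in isolation, per the pre-call and success/failure hunks in this diff:

# Incremented in the pre-call hook, decremented in the success/failure hooks;
# local_only=True bypasses Redis, making this a per-instance limit.
await internal_usage_cache.async_increment_cache(
    key="global_max_parallel_requests", value=1, local_only=True
)
# ... request runs ...
await internal_usage_cache.async_increment_cache(
    key="global_max_parallel_requests", value=-1, local_only=True
)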
@@ -123,7 +127,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             # CHECK IF REQUEST ALLOWED for key
-            current = cache.get_cache(
+            current = await self.internal_usage_cache.async_get_cache(
                 key=request_count_api_key
             )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
             self.print_verbose(f"current: {current}")
@@ -143,7 +147,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     "current_tpm": 0,
                     "current_rpm": 0,
                 }
-                cache.set_cache(request_count_api_key, new_val)
+                await self.internal_usage_cache.async_set_cache(
+                    request_count_api_key, new_val
+                )
             elif (
                 int(current["current_requests"]) < max_parallel_requests
                 and current["current_tpm"] < tpm_limit
@@ -155,7 +161,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     "current_tpm": current["current_tpm"],
                     "current_rpm": current["current_rpm"],
                 }
-                cache.set_cache(request_count_api_key, new_val)
+                await self.internal_usage_cache.async_set_cache(
+                    request_count_api_key, new_val
+                )
             else:
                 raise HTTPException(
                     status_code=429, detail="Max parallel request limit reached."
@@ -164,7 +172,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
-            _user_id_rate_limits = await self.user_api_key_cache.async_get_cache(
+            _user_id_rate_limits = await self.internal_usage_cache.async_get_cache(
                 key=user_id
             )
             # get user tpm/rpm limits
@@ -256,7 +264,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
+            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
                 "global_max_parallel_requests", None
             )
@@ -269,9 +277,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             )
             user_api_key_end_user_id = kwargs.get("user")
 
-            if self.user_api_key_cache is None:
-                return
-
             # ------------
             # Setup values
             # ------------
@@ -280,7 +285,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )
@@ -303,7 +308,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -320,7 +325,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
@@ -337,7 +342,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_user_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -354,7 +359,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
@@ -371,7 +376,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_team_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -388,7 +393,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
@@ -405,7 +410,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_end_user_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -422,7 +427,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
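
All four success-hook updates above follow the same scheme: one counter blob per entity (key, user, team, end user) per minute, written with ttl=60 so stale minutes expire on their own. Roughly (the precise_minute format is an assumption inferred from the strftime calls earlier in the diff):

from datetime import datetime

now = datetime.now()
precise_minute = f"{now.strftime('%Y-%m-%d')}-{now.strftime('%H')}-{now.strftime('%M')}"  # assumed format
request_count_api_key = f"{user_api_key}::{precise_minute}::request_count"

await internal_usage_cache.async_set_cache(
    request_count_api_key, new_val, ttl=60
)  # counters for past minutes simply expire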
@@ -442,9 +447,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             if user_api_key is None:
                 return
 
-            if self.user_api_key_cache is None:
-                return
-
             ## decrement call count if call failed
             if "Max parallel request limit reached" in str(kwargs["exception"]):
                 pass  # ignore failed calls due to max limit being reached
@@ -457,12 +459,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 current_global_requests = (
-                    await self.user_api_key_cache.async_get_cache(
+                    await self.internal_usage_cache.async_get_cache(
                         key=_key, local_only=True
                     )
                 )
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )
@@ -478,7 +480,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             # ------------
             # Update usage
             # ------------
-            current = self.user_api_key_cache.get_cache(
+            current = await self.internal_usage_cache.async_get_cache(
                 key=request_count_api_key
             ) or {
                 "current_requests": 1,
@@ -493,7 +495,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             }
             self.print_verbose(f"updated_value in failure call: {new_val}")
 
-            self.user_api_key_cache.set_cache(
+            await self.internal_usage_cache.async_set_cache(
                 request_count_api_key, new_val, ttl=60
            )  # save in cache for up to 1 min.
        except Exception as e:
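
The failure hook mirrors the success path: unless the failure was this limiter's own 429, it releases the claimed slot by decrementing current_requests. A condensed sketch of that bookkeeping (variable names assumed to be in scope):

current = await internal_usage_cache.async_get_cache(key=request_count_api_key) or {
    "current_requests": 1,
    "current_tpm": 0,
    "current_rpm": 0,
}
new_val = {
    "current_requests": current["current_requests"] - 1,  # release the slot
    "current_tpm": current["current_tpm"],
    "current_rpm": current["current_rpm"],
}
await internal_usage_cache.async_set_cache(request_count_api_key, new_val, ttl=60)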