fix(parallel_request_limiter.py): use redis cache, if available, for rate limiting across instances
Fixes https://github.com/BerriAI/litellm/issues/4148
commit 76c9b715f2
parent c059352908
4 changed files with 75 additions and 56 deletions
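In short: the limiter stops using its per-instance user_api_key_cache and instead reads and writes every counter through the proxy's shared internal_usage_cache, a DualCache that writes through to Redis when one is configured, so the limits hold across proxy instances. A minimal sketch of the new wiring, assuming litellm's DualCache and RedisCache constructors; the Redis connection details below are illustrative:

# Sketch only: hand the limiter a shared cache. Assumes litellm's
# DualCache/RedisCache APIs; host/port values are illustrative.
from litellm.caching import DualCache, RedisCache
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
)

# DualCache reads the in-memory cache first and falls back to Redis;
# writes go to both, so counters are visible to every proxy instance.
internal_usage_cache = DualCache(redis_cache=RedisCache(host="localhost", port=6379))

handler = _PROXY_MaxParallelRequestsHandler(internal_usage_cache=internal_usage_cache)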
|
litellm/proxy/hooks/parallel_request_limiter.py

@@ -10,18 +10,17 @@ from datetime import datetime


 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
-    user_api_key_cache = None

     # Class variables or attributes
-    def __init__(self):
-        pass
+    def __init__(self, internal_usage_cache: DualCache):
+        self.internal_usage_cache = internal_usage_cache

     def print_verbose(self, print_statement):
         try:
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement)  # noqa
-        except:
+        except Exception:
             pass

     async def check_key_in_limits(
@@ -35,7 +34,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         rpm_limit: int,
         request_count_api_key: str,
     ):
-        current = cache.get_cache(
+        current = await self.internal_usage_cache.async_get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         if current is None:
@@ -49,7 +48,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": 0,
                 "current_rpm": 0,
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         elif (
             int(current["current_requests"]) < max_parallel_requests
             and current["current_tpm"] < tpm_limit
@@ -61,7 +62,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": current["current_tpm"],
                 "current_rpm": current["current_rpm"],
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         else:
             raise HTTPException(
                 status_code=429,
@@ -75,7 +78,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         data: dict,
         call_type: str,
     ):
-        self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
+        self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests
         if max_parallel_requests is None:
@@ -90,7 +93,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if rpm_limit is None:
             rpm_limit = sys.maxsize

-        self.user_api_key_cache = cache  # save the api key cache for updating the value
         # ------------
         # Setup values
         # ------------
@@ -98,7 +100,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if global_max_parallel_requests is not None:
             # get value from cache
             _key = "global_max_parallel_requests"
-            current_global_requests = await cache.async_get_cache(
+            current_global_requests = await self.internal_usage_cache.async_get_cache(
                 key=_key, local_only=True
             )
             # check if below limit
@@ -111,7 +113,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 )
             # if below -> increment
             else:
-                await cache.async_increment_cache(key=_key, value=1, local_only=True)
+                await self.internal_usage_cache.async_increment_cache(
+                    key=_key, value=1, local_only=True
+                )

         current_date = datetime.now().strftime("%Y-%m-%d")
         current_hour = datetime.now().strftime("%H")
@@ -123,7 +127,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):

             # CHECK IF REQUEST ALLOWED for key

-            current = cache.get_cache(
+            current = await self.internal_usage_cache.async_get_cache(
                 key=request_count_api_key
             )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
             self.print_verbose(f"current: {current}")
@@ -143,7 +147,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     "current_tpm": 0,
                     "current_rpm": 0,
                 }
-                cache.set_cache(request_count_api_key, new_val)
+                await self.internal_usage_cache.async_set_cache(
+                    request_count_api_key, new_val
+                )
             elif (
                 int(current["current_requests"]) < max_parallel_requests
                 and current["current_tpm"] < tpm_limit
@@ -155,7 +161,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     "current_tpm": current["current_tpm"],
                     "current_rpm": current["current_rpm"],
                 }
-                cache.set_cache(request_count_api_key, new_val)
+                await self.internal_usage_cache.async_set_cache(
+                    request_count_api_key, new_val
+                )
             else:
                 raise HTTPException(
                     status_code=429, detail="Max parallel request limit reached."
@@ -164,7 +172,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
-            _user_id_rate_limits = await self.user_api_key_cache.async_get_cache(
+            _user_id_rate_limits = await self.internal_usage_cache.async_get_cache(
                 key=user_id
             )
             # get user tpm/rpm limits
@@ -256,7 +264,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):

     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
+            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
                 "global_max_parallel_requests", None
             )
@@ -269,9 +277,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             )
             user_api_key_end_user_id = kwargs.get("user")

-            if self.user_api_key_cache is None:
-                return
-
             # ------------
             # Setup values
             # ------------
@@ -280,7 +285,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )

@@ -303,7 +308,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key}::{precise_minute}::request_count"
                 )

-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -320,7 +325,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.

@@ -337,7 +342,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_user_id}::{precise_minute}::request_count"
                 )

-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -354,7 +359,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.

@@ -371,7 +376,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_team_id}::{precise_minute}::request_count"
                 )

-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -388,7 +393,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.

@@ -405,7 +410,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_end_user_id}::{precise_minute}::request_count"
                 )

-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -422,7 +427,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.

@@ -442,9 +447,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             if user_api_key is None:
                 return

-            if self.user_api_key_cache is None:
-                return
-
             ## decrement call count if call failed
             if "Max parallel request limit reached" in str(kwargs["exception"]):
                 pass  # ignore failed calls due to max limit being reached
@@ -457,12 +459,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 current_global_requests = (
-                    await self.user_api_key_cache.async_get_cache(
+                    await self.internal_usage_cache.async_get_cache(
                         key=_key, local_only=True
                     )
                 )
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )

@@ -478,7 +480,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             # ------------
             # Update usage
             # ------------
-            current = self.user_api_key_cache.get_cache(
+            current = await self.internal_usage_cache.async_get_cache(
                 key=request_count_api_key
             ) or {
                 "current_requests": 1,
@@ -493,7 +495,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             }

             self.print_verbose(f"updated_value in failure call: {new_val}")
-            self.user_api_key_cache.set_cache(
+            await self.internal_usage_cache.async_set_cache(
                 request_count_api_key, new_val, ttl=60
             )  # save in cache for up to 1 min.
         except Exception as e:
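Every one of these counter updates now follows the same async read-modify-write pattern against the shared cache. A simplified, self-contained restatement of the success-hook update: the helper name and the total_tokens parameter are illustrative, while the key format, default value, and ttl follow the diff above.

# Illustrative restatement of the success-hook update; not the literal
# source. `internal_usage_cache` is the shared DualCache from __init__.
async def update_usage(internal_usage_cache, request_count_api_key: str, total_tokens: int):
    # read the counters this proxy or any other instance last wrote
    current = await internal_usage_cache.async_get_cache(
        key=request_count_api_key
    ) or {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}
    new_val = {
        "current_requests": max(current["current_requests"] - 1, 0),  # release the slot
        "current_tpm": current["current_tpm"] + total_tokens,
        "current_rpm": current["current_rpm"] + 1,
    }
    # write back through the DualCache (memory + Redis when configured)
    await internal_usage_cache.async_set_cache(
        request_count_api_key, new_val, ttl=60
    )  # per-minute counters expire after 60s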