fix(parallel_request_limiter.py): fix max parallel request limiter on retries

Krrish Dholakia 2024-05-15 20:16:11 -07:00
parent 153ce0d085
commit 594ca947c8
4 changed files with 100 additions and 6 deletions

parallel_request_limiter.py

@@ -79,6 +79,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
max_parallel_requests = user_api_key_dict.max_parallel_requests
if max_parallel_requests is None:
max_parallel_requests = sys.maxsize
global_max_parallel_requests = data.get("metadata", {}).get(
"global_max_parallel_requests", None
)
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
if tpm_limit is None:
tpm_limit = sys.maxsize
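
The hunk above only threads the limit through: global_max_parallel_requests rides in the request's metadata dict, and a missing key or absent metadata both collapse to None, which the hook treats as "no global limit configured". A tiny illustration (the data payload below is hypothetical):

# Hypothetical request payload; only the metadata shape matters here.
data = {"metadata": {"global_max_parallel_requests": 100}}

global_max_parallel_requests = data.get("metadata", {}).get(
    "global_max_parallel_requests", None
)
assert global_max_parallel_requests == 100

# A payload that never set the key (or has no metadata at all) yields None.
assert {}.get("metadata", {}).get("global_max_parallel_requests", None) is None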
@@ -91,6 +94,24 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = await cache.async_get_cache(
key=_key, local_only=True
)
# check if below limit
if current_global_requests is None:
current_global_requests = 1
# if above -> raise error
if current_global_requests >= global_max_parallel_requests:
raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
# if below -> increment
else:
await cache.async_increment_cache(key=_key, value=1, local_only=True)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
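
Taken together, the hunk above is a check-then-increment on a single shared counter: read the current in-flight count, reject once it reaches the global cap, otherwise bump it. Below is a minimal, runnable sketch of that pattern; InMemoryCache is a hypothetical stand-in for the proxy's cache (only the two methods the hunk calls are modeled), and a plain exception stands in for the HTTPException(status_code=429) raised above:

import asyncio


class InMemoryCache:
    # Hypothetical stand-in: models only the two cache methods the hunk uses.
    def __init__(self):
        self._store: dict = {}

    async def async_get_cache(self, key: str, local_only: bool = False):
        return self._store.get(key)

    async def async_increment_cache(self, key: str, value: int, local_only: bool = False):
        self._store[key] = self._store.get(key, 0) + value
        return self._store[key]


class MaxParallelLimitError(Exception):
    pass


async def pre_call_check(cache: InMemoryCache, global_max_parallel_requests: int):
    current = await cache.async_get_cache(
        key="global_max_parallel_requests", local_only=True
    )
    if current is None:
        current = 1  # first in-flight request, mirroring the hunk
    if current >= global_max_parallel_requests:
        raise MaxParallelLimitError("Max parallel request limit reached.")
    await cache.async_increment_cache(
        key="global_max_parallel_requests", value=1, local_only=True
    )


async def main():
    cache = InMemoryCache()
    await pre_call_check(cache, global_max_parallel_requests=2)  # counter -> 1
    await pre_call_check(cache, global_max_parallel_requests=2)  # counter -> 2
    try:
        await pre_call_check(cache, global_max_parallel_requests=2)
    except MaxParallelLimitError as e:
        print(e)  # third concurrent request is rejected


asyncio.run(main())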
@@ -207,6 +228,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
@@ -222,6 +246,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
# decrement
await self.user_api_key_cache.async_increment_cache(
key=_key, value=-1, local_only=True
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
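
The success hook is the release side of that counter: passing value=-1 to the same increment call undoes the pre-call +1. Continuing the hypothetical InMemoryCache sketch from above:

async def post_call_success(cache: InMemoryCache):
    # Same key as the pre-call hook; -1 releases the slot taken by +1.
    await cache.async_increment_cache(
        key="global_max_parallel_requests", value=-1, local_only=True
    )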
@@ -336,6 +368,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = (
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
)
@@ -347,17 +382,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
return
## decrement call count if call failed
-           if (
-               hasattr(kwargs["exception"], "status_code")
-               and kwargs["exception"].status_code == 429
-               and "Max parallel request limit reached" in str(kwargs["exception"])
-           ):
+           if "Max parallel request limit reached" in str(kwargs["exception"]):
pass # ignore failed calls due to max limit being reached
else:
# ------------
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = (
await self.user_api_key_cache.async_get_cache(
key=_key, local_only=True
)
)
# decrement
await self.user_api_key_cache.async_increment_cache(
key=_key, value=-1, local_only=True
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
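
The failure hook releases the slot too, with one guard: a call the limiter itself rejected never incremented the counter, so decrementing for it would drive the count low and quietly raise effective capacity on retries. Matching on the exception message alone, rather than also requiring a status_code of 429 as the old condition did, presumably keeps the guard working when a retry layer re-wraps the exception and drops its status_code attribute. Continuing the sketch:

async def post_call_failure(cache: InMemoryCache, exception: Exception):
    # A call the limiter rejected never took a slot, so there is nothing
    # to release; decrementing here would leak capacity across retries.
    if "Max parallel request limit reached" in str(exception):
        return
    await cache.async_increment_cache(
        key="global_max_parallel_requests", value=-1, local_only=True
    )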