test(test_parallel_request_limiter.py): unit testing for tpm/rpm rate limits

This commit is contained in:
Krrish Dholakia 2024-01-18 15:28:28 -08:00
parent 3957a8303a
commit 34c3b33b37
2 changed files with 17 additions and 5 deletions

View file

@@ -1,5 +1,5 @@
from typing import Optional
import litellm, traceback
import litellm, traceback, sys
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
@@ -28,12 +28,18 @@ class MaxParallelRequestsHandler(CustomLogger):
):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
tpm_limit = user_api_key_dict.tpm_limit
rpm_limit = user_api_key_dict.rpm_limit
if api_key is None:
return
if max_parallel_requests is None:
if (
max_parallel_requests == sys.maxsize
and tpm_limit == sys.maxsize
and rpm_limit == sys.maxsize
):
return
self.user_api_key_cache = cache # save the api key cache for updating the value
@@ -60,7 +66,11 @@ class MaxParallelRequestsHandler(CustomLogger):
"current_rpm": 0,
}
cache.set_cache(request_count_api_key, new_val)
elif int(current["current_requests"]) < max_parallel_requests:
elif (
int(current["current_requests"]) < max_parallel_requests
and current["current_tpm"] < tpm_limit
and current["current_rpm"] < rpm_limit
):
# Increase count for this token
new_val = {
"current_requests": current["current_requests"] + 1,