fix(parallel_request_limiter.py): fix team rate limit enforcement

This commit is contained in:
Krrish Dholakia 2024-02-26 18:06:13 -08:00
parent f84ac35000
commit f86ab19067
3 changed files with 105 additions and 56 deletions

View file

@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
current = cache.get_cache(
key=request_count_api_key
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
# print(f"current: {current}")
if current is None:
new_val = {
"current_requests": 1,
@ -73,8 +72,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
if api_key is None:
return
@ -131,10 +130,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
# get user tpm/rpm limits
if _user_id_rate_limits is None or _user_id_rate_limits == {}:
return
user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict):
user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
@ -158,8 +156,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# TEAM RATE LIMITS
## get team tpm/rpm limits
team_id = user_api_key_dict.team_id
team_tpm_limit = user_api_key_dict.team_tpm_limit or sys.maxsize
team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize
team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
if team_tpm_limit is None:
team_tpm_limit = sys.maxsize
@ -241,9 +239,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# ------------
# Update usage - User
# ------------
if user_api_key_user_id is None:
return
if user_api_key_user_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
@ -253,7 +249,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
f"{user_api_key_user_id}::{precise_minute}::request_count"
)
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
current = self.user_api_key_cache.get_cache(
key=request_count_api_key
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,

View file

@ -350,8 +350,7 @@ async def user_api_key_auth(
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
if api_key.startswith("sk-"):
api_key = hash_token(token=api_key)
# valid_token = user_api_key_cache.get_cache(key=api_key)
valid_token = None
valid_token = user_api_key_cache.get_cache(key=api_key)
if valid_token is None:
## check db
verbose_proxy_logger.debug(f"api key: {api_key}")
@ -384,7 +383,6 @@ async def user_api_key_auth(
# 6. If token spend per model is under budget per model
# 7. If token spend is under team budget
# 8. If team spend is under team budget
request_data = await _read_request_body(
request=request
) # request data, used across all checks. Making this easily available
@ -627,7 +625,7 @@ async def user_api_key_auth(
)
)
if valid_token.spend > valid_token.team_max_budget:
if valid_token.spend >= valid_token.team_max_budget:
raise Exception(
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
)
@ -646,7 +644,7 @@ async def user_api_key_auth(
)
)
if valid_token.team_spend > valid_token.team_max_budget:
if valid_token.team_spend >= valid_token.team_max_budget:
raise Exception(
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
)

View file

@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits():
assert e.status_code == 429
@pytest.mark.asyncio
async def test_pre_call_hook_team_rpm_limits():
"""
Test if error raised on hitting team rpm limits
"""
litellm.set_verbose = True
_api_key = "sk-12345"
_team_id = "unique-team-id"
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key,
max_parallel_requests=1,
tpm_limit=9,
rpm_limit=10,
team_rpm_limit=1,
team_id=_team_id,
)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
)
kwargs = {
"litellm_params": {
"metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
}
}
await parallel_request_handler.async_log_success_event(
kwargs=kwargs,
response_obj="",
start_time="",
end_time="",
)
print(f"local_cache: {local_cache}")
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
try:
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={},
call_type="",
)
pytest.fail(f"Expected call to fail")
except Exception as e:
assert e.status_code == 429
@pytest.mark.asyncio
async def test_pre_call_hook_tpm_limits():
"""