fix(parallel_request_limiter.py): fix team rate limit enforcement

Krrish Dholakia 2024-02-26 18:06:13 -08:00
parent f84ac35000
commit f86ab19067
3 changed files with 105 additions and 56 deletions


@@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         current = cache.get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
-        # print(f"current: {current}")
         if current is None:
             new_val = {
                 "current_requests": 1,
@@ -73,8 +72,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
-        tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
-        rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
+        tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
+        rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
 
         if api_key is None:
             return
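
Note on the getattr change above: `user_api_key_dict.tpm_limit or sys.maxsize` assumes the attribute exists and raises AttributeError when it does not, while `getattr(..., sys.maxsize)` only falls back when the attribute is missing, not when it is present but set to None, which is why explicit `is None` checks still follow for the team limits further down. A minimal sketch of the difference (the `_Key` class is a hypothetical stand-in for `UserAPIKeyAuth`, for illustration only):

    import sys

    class _Key:  # hypothetical stand-in for UserAPIKeyAuth
        tpm_limit = None  # field present but unset
        # no team_tpm_limit attribute defined at all

    key = _Key()

    # old pattern: raises AttributeError because the attribute is missing
    # team_tpm_limit = key.team_tpm_limit or sys.maxsize

    # new pattern: falls back to sys.maxsize when the attribute is absent
    team_tpm_limit = getattr(key, "team_tpm_limit", sys.maxsize)

    # the default does not apply when the attribute exists but is None,
    # hence the explicit follow-up check kept in the handler
    tpm_limit = getattr(key, "tpm_limit", sys.maxsize)
    if tpm_limit is None:
        tpm_limit = sys.maxsize
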
@@ -131,35 +130,34 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             _user_id_rate_limits = user_api_key_dict.user_id_rate_limits
 
             # get user tpm/rpm limits
-            if _user_id_rate_limits is None or _user_id_rate_limits == {}:
-                return
-            user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
-            user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
-            if user_tpm_limit is None:
-                user_tpm_limit = sys.maxsize
-            if user_rpm_limit is None:
-                user_rpm_limit = sys.maxsize
-
-            # now do the same tpm/rpm checks
-            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
-
-            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
-            await self.check_key_in_limits(
-                user_api_key_dict=user_api_key_dict,
-                cache=cache,
-                data=data,
-                call_type=call_type,
-                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
-                request_count_api_key=request_count_api_key,
-                tpm_limit=user_tpm_limit,
-                rpm_limit=user_rpm_limit,
-            )
+            if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict):
+                user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
+                user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
+                if user_tpm_limit is None:
+                    user_tpm_limit = sys.maxsize
+                if user_rpm_limit is None:
+                    user_rpm_limit = sys.maxsize
+
+                # now do the same tpm/rpm checks
+                request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+
+                # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+                await self.check_key_in_limits(
+                    user_api_key_dict=user_api_key_dict,
+                    cache=cache,
+                    data=data,
+                    call_type=call_type,
+                    max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
+                    request_count_api_key=request_count_api_key,
+                    tpm_limit=user_tpm_limit,
+                    rpm_limit=user_rpm_limit,
+                )
 
         # TEAM RATE LIMITS
         ## get team tpm/rpm limits
         team_id = user_api_key_dict.team_id
-        team_tpm_limit = user_api_key_dict.team_tpm_limit or sys.maxsize
-        team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize
+        team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
+        team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
 
         if team_tpm_limit is None:
             team_tpm_limit = sys.maxsize
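
The hunk above is the core of the team rate limit fix: previously a key with no per-user rate limits hit the early return, so the team tpm/rpm checks lower in the hook were never reached. Moving the user check into a guarded block lets execution fall through to the team limits. A simplified, synchronous sketch of the control flow (not the actual handler; `check_limits` is a placeholder callable):

    import sys
    from typing import Callable, Optional

    def pre_call_flow_sketch(
        user_limits: Optional[dict],
        team_rpm_limit: Optional[int],
        check_limits: Callable[..., None],
    ) -> None:
        # old flow (simplified): missing user limits returned early,
        # so the team check below never executed.
        #
        #     if user_limits is None or user_limits == {}:
        #         return
        #
        # new flow: the user check is guarded, and execution always
        # continues on to the team check.
        if user_limits:
            check_limits(
                tpm_limit=user_limits.get("tpm_limit") or sys.maxsize,
                rpm_limit=user_limits.get("rpm_limit") or sys.maxsize,
            )

        # team rpm limit is now enforced even when no user limits are set
        check_limits(tpm_limit=sys.maxsize, rpm_limit=team_rpm_limit or sys.maxsize)
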
@@ -241,36 +239,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # ------------
         # Update usage - User
         # ------------
-        if user_api_key_user_id is None:
-            return
-
-        total_tokens = 0
-
-        if isinstance(response_obj, ModelResponse):
-            total_tokens = response_obj.usage.total_tokens
-
-        request_count_api_key = (
-            f"{user_api_key_user_id}::{precise_minute}::request_count"
-        )
-
-        current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
-            "current_requests": 1,
-            "current_tpm": total_tokens,
-            "current_rpm": 1,
-        }
-
-        new_val = {
-            "current_requests": max(current["current_requests"] - 1, 0),
-            "current_tpm": current["current_tpm"] + total_tokens,
-            "current_rpm": current["current_rpm"] + 1,
-        }
-
-        self.print_verbose(
-            f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
-        )
-        self.user_api_key_cache.set_cache(
-            request_count_api_key, new_val, ttl=60
-        )  # store in cache for 1 min.
+        if user_api_key_user_id is not None:
+            total_tokens = 0
+
+            if isinstance(response_obj, ModelResponse):
+                total_tokens = response_obj.usage.total_tokens
+
+            request_count_api_key = (
+                f"{user_api_key_user_id}::{precise_minute}::request_count"
+            )
+
+            current = self.user_api_key_cache.get_cache(
+                key=request_count_api_key
+            ) or {
+                "current_requests": 1,
+                "current_tpm": total_tokens,
+                "current_rpm": 1,
+            }
+
+            new_val = {
+                "current_requests": max(current["current_requests"] - 1, 0),
+                "current_tpm": current["current_tpm"] + total_tokens,
+                "current_rpm": current["current_rpm"] + 1,
+            }
+
+            self.print_verbose(
+                f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+            )
+            self.user_api_key_cache.set_cache(
+                request_count_api_key, new_val, ttl=60
+            )  # store in cache for 1 min.
 
         # ------------
         # Update usage - Team
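
The counter update above is unchanged in substance; it is only moved under the `is not None` guard so the team usage update that follows still runs when no user id is attached to the key. A worked example of the arithmetic with illustrative numbers (10 tokens in the completed response, one in-flight request recorded by the pre-call hook):

    # value the pre-call hook stored for this minute (illustrative numbers)
    current = {"current_requests": 1, "current_tpm": 9, "current_rpm": 1}
    total_tokens = 10  # from response_obj.usage.total_tokens

    new_val = {
        # release the parallel-request slot, never dropping below zero
        "current_requests": max(current["current_requests"] - 1, 0),
        # accumulate tokens and requests for tpm/rpm enforcement
        "current_tpm": current["current_tpm"] + total_tokens,
        "current_rpm": current["current_rpm"] + 1,
    }
    assert new_val == {"current_requests": 0, "current_tpm": 19, "current_rpm": 2}
    # stored back with ttl=60 so the counters reset every minute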


@@ -350,8 +350,7 @@ async def user_api_key_auth(
         original_api_key = api_key  # (Patch: For DynamoDB Backwards Compatibility)
         if api_key.startswith("sk-"):
             api_key = hash_token(token=api_key)
-        # valid_token = user_api_key_cache.get_cache(key=api_key)
-        valid_token = None
+        valid_token = user_api_key_cache.get_cache(key=api_key)
         if valid_token is None:
             ## check db
             verbose_proxy_logger.debug(f"api key: {api_key}")
@@ -384,7 +383,6 @@ async def user_api_key_auth(
             # 6. If token spend per model is under budget per model
             # 7. If token spend is under team budget
             # 8. If team spend is under team budget
-
             request_data = await _read_request_body(
                 request=request
             )  # request data, used across all checks. Making this easily available
@@ -627,7 +625,7 @@ async def user_api_key_auth(
                     )
                 )
-                if valid_token.spend > valid_token.team_max_budget:
+                if valid_token.spend >= valid_token.team_max_budget:
                     raise Exception(
                         f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
                     )
@@ -646,7 +644,7 @@ async def user_api_key_auth(
                     )
                 )
-                if valid_token.team_spend > valid_token.team_max_budget:
+                if valid_token.team_spend >= valid_token.team_max_budget:
                     raise Exception(
                         f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
                     )
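
The two comparison changes above (`>` to `>=`) close a small gap: a token or team whose spend exactly equals `team_max_budget` is now rejected instead of being allowed one more request. A trivial illustration:

    spend, team_max_budget = 10.0, 10.0

    assert not (spend > team_max_budget)  # old check: request would still pass
    assert spend >= team_max_budget       # new check: budget treated as exhausted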


@@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits():
         assert e.status_code == 429
 
 
+@pytest.mark.asyncio
+async def test_pre_call_hook_team_rpm_limits():
+    """
+    Test if error raised on hitting team rpm limits
+    """
+    litellm.set_verbose = True
+    _api_key = "sk-12345"
+    _team_id = "unique-team-id"
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=1,
+        tpm_limit=9,
+        rpm_limit=10,
+        team_rpm_limit=1,
+        team_id=_team_id,
+    )
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler()
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    kwargs = {
+        "litellm_params": {
+            "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
+        }
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj="",
+        start_time="",
+        end_time="",
+    )
+
+    print(f"local_cache: {local_cache}")
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+
+
 @pytest.mark.asyncio
 async def test_pre_call_hook_tpm_limits():
     """