forked from phoenix/litellm-mirror
fix(parallel_request_limiter.py): fix team rate limit enforcement
This commit is contained in:
parent
f84ac35000
commit
f86ab19067
3 changed files with 105 additions and 56 deletions
|
@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
current = cache.get_cache(
|
||||
key=request_count_api_key
|
||||
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
||||
# print(f"current: {current}")
|
||||
if current is None:
|
||||
new_val = {
|
||||
"current_requests": 1,
|
||||
|
@ -73,8 +72,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
||||
api_key = user_api_key_dict.api_key
|
||||
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
||||
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
|
||||
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
|
||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
||||
|
||||
if api_key is None:
|
||||
return
|
||||
|
@ -131,10 +130,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
|
||||
|
||||
# get user tpm/rpm limits
|
||||
if _user_id_rate_limits is None or _user_id_rate_limits == {}:
|
||||
return
|
||||
user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
|
||||
user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
|
||||
if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict):
|
||||
user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
|
||||
user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
|
||||
if user_tpm_limit is None:
|
||||
user_tpm_limit = sys.maxsize
|
||||
if user_rpm_limit is None:
|
||||
|
@ -158,8 +156,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
# TEAM RATE LIMITS
|
||||
## get team tpm/rpm limits
|
||||
team_id = user_api_key_dict.team_id
|
||||
team_tpm_limit = user_api_key_dict.team_tpm_limit or sys.maxsize
|
||||
team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize
|
||||
team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
|
||||
team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
|
||||
|
||||
if team_tpm_limit is None:
|
||||
team_tpm_limit = sys.maxsize
|
||||
|
@ -241,9 +239,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
# ------------
|
||||
# Update usage - User
|
||||
# ------------
|
||||
if user_api_key_user_id is None:
|
||||
return
|
||||
|
||||
if user_api_key_user_id is not None:
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
|
@ -253,7 +249,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
|
||||
current = self.user_api_key_cache.get_cache(
|
||||
key=request_count_api_key
|
||||
) or {
|
||||
"current_requests": 1,
|
||||
"current_tpm": total_tokens,
|
||||
"current_rpm": 1,
|
||||
|
|
|
@ -350,8 +350,7 @@ async def user_api_key_auth(
|
|||
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
|
||||
if api_key.startswith("sk-"):
|
||||
api_key = hash_token(token=api_key)
|
||||
# valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||
valid_token = None
|
||||
valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||
if valid_token is None:
|
||||
## check db
|
||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||
|
@ -384,7 +383,6 @@ async def user_api_key_auth(
|
|||
# 6. If token spend per model is under budget per model
|
||||
# 7. If token spend is under team budget
|
||||
# 8. If team spend is under team budget
|
||||
|
||||
request_data = await _read_request_body(
|
||||
request=request
|
||||
) # request data, used across all checks. Making this easily available
|
||||
|
@ -627,7 +625,7 @@ async def user_api_key_auth(
|
|||
)
|
||||
)
|
||||
|
||||
if valid_token.spend > valid_token.team_max_budget:
|
||||
if valid_token.spend >= valid_token.team_max_budget:
|
||||
raise Exception(
|
||||
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||
)
|
||||
|
@ -646,7 +644,7 @@ async def user_api_key_auth(
|
|||
)
|
||||
)
|
||||
|
||||
if valid_token.team_spend > valid_token.team_max_budget:
|
||||
if valid_token.team_spend >= valid_token.team_max_budget:
|
||||
raise Exception(
|
||||
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||
)
|
||||
|
|
|
@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits():
|
|||
assert e.status_code == 429
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_team_rpm_limits():
|
||||
"""
|
||||
Test if error raised on hitting team rpm limits
|
||||
"""
|
||||
litellm.set_verbose = True
|
||||
_api_key = "sk-12345"
|
||||
_team_id = "unique-team-id"
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key=_api_key,
|
||||
max_parallel_requests=1,
|
||||
tpm_limit=9,
|
||||
rpm_limit=10,
|
||||
team_rpm_limit=1,
|
||||
team_id=_team_id,
|
||||
)
|
||||
local_cache = DualCache()
|
||||
parallel_request_handler = MaxParallelRequestsHandler()
|
||||
|
||||
await parallel_request_handler.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
|
||||
}
|
||||
}
|
||||
|
||||
await parallel_request_handler.async_log_success_event(
|
||||
kwargs=kwargs,
|
||||
response_obj="",
|
||||
start_time="",
|
||||
end_time="",
|
||||
)
|
||||
|
||||
print(f"local_cache: {local_cache}")
|
||||
|
||||
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||
|
||||
try:
|
||||
await parallel_request_handler.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=local_cache,
|
||||
data={},
|
||||
call_type="",
|
||||
)
|
||||
|
||||
pytest.fail(f"Expected call to fail")
|
||||
except Exception as e:
|
||||
assert e.status_code == 429
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook_tpm_limits():
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue