forked from phoenix/litellm-mirror
fix(parallel_request_limiter.py): fix team rate limit enforcement
This commit is contained in:
parent
f84ac35000
commit
f86ab19067
3 changed files with 105 additions and 56 deletions
|
@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
current = cache.get_cache(
|
current = cache.get_cache(
|
||||||
key=request_count_api_key
|
key=request_count_api_key
|
||||||
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
||||||
# print(f"current: {current}")
|
|
||||||
if current is None:
|
if current is None:
|
||||||
new_val = {
|
new_val = {
|
||||||
"current_requests": 1,
|
"current_requests": 1,
|
||||||
|
@ -73,8 +72,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
||||||
api_key = user_api_key_dict.api_key
|
api_key = user_api_key_dict.api_key
|
||||||
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
||||||
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
|
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||||
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
|
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
||||||
|
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
return
|
return
|
||||||
|
@ -131,35 +130,34 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
|
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
|
||||||
|
|
||||||
# get user tpm/rpm limits
|
# get user tpm/rpm limits
|
||||||
if _user_id_rate_limits is None or _user_id_rate_limits == {}:
|
if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict):
|
||||||
return
|
user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
|
||||||
user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
|
user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
|
||||||
user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
|
if user_tpm_limit is None:
|
||||||
if user_tpm_limit is None:
|
user_tpm_limit = sys.maxsize
|
||||||
user_tpm_limit = sys.maxsize
|
if user_rpm_limit is None:
|
||||||
if user_rpm_limit is None:
|
user_rpm_limit = sys.maxsize
|
||||||
user_rpm_limit = sys.maxsize
|
|
||||||
|
|
||||||
# now do the same tpm/rpm checks
|
# now do the same tpm/rpm checks
|
||||||
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||||
|
|
||||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||||
await self.check_key_in_limits(
|
await self.check_key_in_limits(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
data=data,
|
data=data,
|
||||||
call_type=call_type,
|
call_type=call_type,
|
||||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
||||||
request_count_api_key=request_count_api_key,
|
request_count_api_key=request_count_api_key,
|
||||||
tpm_limit=user_tpm_limit,
|
tpm_limit=user_tpm_limit,
|
||||||
rpm_limit=user_rpm_limit,
|
rpm_limit=user_rpm_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TEAM RATE LIMITS
|
# TEAM RATE LIMITS
|
||||||
## get team tpm/rpm limits
|
## get team tpm/rpm limits
|
||||||
team_id = user_api_key_dict.team_id
|
team_id = user_api_key_dict.team_id
|
||||||
team_tpm_limit = user_api_key_dict.team_tpm_limit or sys.maxsize
|
team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
|
||||||
team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize
|
team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
|
||||||
|
|
||||||
if team_tpm_limit is None:
|
if team_tpm_limit is None:
|
||||||
team_tpm_limit = sys.maxsize
|
team_tpm_limit = sys.maxsize
|
||||||
|
@ -241,36 +239,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage - User
|
# Update usage - User
|
||||||
# ------------
|
# ------------
|
||||||
if user_api_key_user_id is None:
|
if user_api_key_user_id is not None:
|
||||||
return
|
total_tokens = 0
|
||||||
|
|
||||||
total_tokens = 0
|
if isinstance(response_obj, ModelResponse):
|
||||||
|
total_tokens = response_obj.usage.total_tokens
|
||||||
|
|
||||||
if isinstance(response_obj, ModelResponse):
|
request_count_api_key = (
|
||||||
total_tokens = response_obj.usage.total_tokens
|
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
||||||
|
)
|
||||||
|
|
||||||
request_count_api_key = (
|
current = self.user_api_key_cache.get_cache(
|
||||||
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
key=request_count_api_key
|
||||||
)
|
) or {
|
||||||
|
"current_requests": 1,
|
||||||
|
"current_tpm": total_tokens,
|
||||||
|
"current_rpm": 1,
|
||||||
|
}
|
||||||
|
|
||||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
|
new_val = {
|
||||||
"current_requests": 1,
|
"current_requests": max(current["current_requests"] - 1, 0),
|
||||||
"current_tpm": total_tokens,
|
"current_tpm": current["current_tpm"] + total_tokens,
|
||||||
"current_rpm": 1,
|
"current_rpm": current["current_rpm"] + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
new_val = {
|
self.print_verbose(
|
||||||
"current_requests": max(current["current_requests"] - 1, 0),
|
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||||
"current_tpm": current["current_tpm"] + total_tokens,
|
)
|
||||||
"current_rpm": current["current_rpm"] + 1,
|
self.user_api_key_cache.set_cache(
|
||||||
}
|
request_count_api_key, new_val, ttl=60
|
||||||
|
) # store in cache for 1 min.
|
||||||
self.print_verbose(
|
|
||||||
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
|
||||||
)
|
|
||||||
self.user_api_key_cache.set_cache(
|
|
||||||
request_count_api_key, new_val, ttl=60
|
|
||||||
) # store in cache for 1 min.
|
|
||||||
|
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage - Team
|
# Update usage - Team
|
||||||
|
|
|
@ -350,8 +350,7 @@ async def user_api_key_auth(
|
||||||
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
|
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
|
||||||
if api_key.startswith("sk-"):
|
if api_key.startswith("sk-"):
|
||||||
api_key = hash_token(token=api_key)
|
api_key = hash_token(token=api_key)
|
||||||
# valid_token = user_api_key_cache.get_cache(key=api_key)
|
valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||||
valid_token = None
|
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
## check db
|
## check db
|
||||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||||
|
@ -384,7 +383,6 @@ async def user_api_key_auth(
|
||||||
# 6. If token spend per model is under budget per model
|
# 6. If token spend per model is under budget per model
|
||||||
# 7. If token spend is under team budget
|
# 7. If token spend is under team budget
|
||||||
# 8. If team spend is under team budget
|
# 8. If team spend is under team budget
|
||||||
|
|
||||||
request_data = await _read_request_body(
|
request_data = await _read_request_body(
|
||||||
request=request
|
request=request
|
||||||
) # request data, used across all checks. Making this easily available
|
) # request data, used across all checks. Making this easily available
|
||||||
|
@ -627,7 +625,7 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if valid_token.spend > valid_token.team_max_budget:
|
if valid_token.spend >= valid_token.team_max_budget:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||||
)
|
)
|
||||||
|
@ -646,7 +644,7 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if valid_token.team_spend > valid_token.team_max_budget:
|
if valid_token.team_spend >= valid_token.team_max_budget:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||||
)
|
)
|
||||||
|
|
|
@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits():
|
||||||
assert e.status_code == 429
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_team_rpm_limits():
|
||||||
|
"""
|
||||||
|
Test if error raised on hitting team rpm limits
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = True
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
_team_id = "unique-team-id"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
api_key=_api_key,
|
||||||
|
max_parallel_requests=1,
|
||||||
|
tpm_limit=9,
|
||||||
|
rpm_limit=10,
|
||||||
|
team_rpm_limit=1,
|
||||||
|
team_id=_team_id,
|
||||||
|
)
|
||||||
|
local_cache = DualCache()
|
||||||
|
parallel_request_handler = MaxParallelRequestsHandler()
|
||||||
|
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"litellm_params": {
|
||||||
|
"metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await parallel_request_handler.async_log_success_event(
|
||||||
|
kwargs=kwargs,
|
||||||
|
response_obj="",
|
||||||
|
start_time="",
|
||||||
|
end_time="",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"local_cache: {local_cache}")
|
||||||
|
|
||||||
|
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
data={},
|
||||||
|
call_type="",
|
||||||
|
)
|
||||||
|
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pre_call_hook_tpm_limits():
|
async def test_pre_call_hook_tpm_limits():
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue