forked from phoenix/litellm-mirror
test(test_max_tpm_rpm_limiter.py): unit tests for key + team based tpm rpm limits on proxy
This commit is contained in:
parent
19fc120081
commit
f3e47323b9
2 changed files with 128 additions and 5 deletions
|
@ -102,6 +102,7 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
|
||||||
request_count_api_key: str,
|
request_count_api_key: str,
|
||||||
type: Literal["key", "user", "team"],
|
type: Literal["key", "user", "team"],
|
||||||
):
|
):
|
||||||
|
|
||||||
if type == "key" and user_api_key_dict.api_key is not None:
|
if type == "key" and user_api_key_dict.api_key is not None:
|
||||||
current = current_minute_dict["key"].get(user_api_key_dict.api_key, None)
|
current = current_minute_dict["key"].get(user_api_key_dict.api_key, None)
|
||||||
elif type == "user" and user_api_key_dict.user_id is not None:
|
elif type == "user" and user_api_key_dict.user_id is not None:
|
||||||
|
@ -110,7 +111,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
|
||||||
current = current_minute_dict["team"].get(user_api_key_dict.team_id, None)
|
current = current_minute_dict["team"].get(user_api_key_dict.team_id, None)
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
if current is None:
|
if current is None:
|
||||||
if tpm_limit == 0 or rpm_limit == 0:
|
if tpm_limit == 0 or rpm_limit == 0:
|
||||||
# base case
|
# base case
|
||||||
|
@ -138,10 +138,14 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
|
||||||
## get team tpm/rpm limits
|
## get team tpm/rpm limits
|
||||||
team_id = user_api_key_dict.team_id
|
team_id = user_api_key_dict.team_id
|
||||||
|
|
||||||
|
self.user_api_key_cache = cache
|
||||||
|
|
||||||
_set_limits = self._check_limits_set(
|
_set_limits = self._check_limits_set(
|
||||||
user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id
|
user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.print_verbose(f"_set_limits: {_set_limits}")
|
||||||
|
|
||||||
if _set_limits == False:
|
if _set_limits == False:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -149,8 +153,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
|
||||||
# Setup values
|
# Setup values
|
||||||
# ------------
|
# ------------
|
||||||
|
|
||||||
self.user_api_key_cache = cache
|
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
current_hour = datetime.now().strftime("%H")
|
current_hour = datetime.now().strftime("%H")
|
||||||
current_minute = datetime.now().strftime("%M")
|
current_minute = datetime.now().strftime("%M")
|
||||||
|
@ -247,7 +249,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
|
||||||
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
|
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_team_id", None
|
"user_api_key_team_id", None
|
||||||
)
|
)
|
||||||
|
|
||||||
_limits_set = self._check_limits_set(
|
_limits_set = self._check_limits_set(
|
||||||
user_api_key_cache=self.user_api_key_cache,
|
user_api_key_cache=self.user_api_key_cache,
|
||||||
key=user_api_key,
|
key=user_api_key,
|
||||||
|
@ -377,4 +378,4 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger):
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.print_verbose(e) # noqa
|
self.print_verbose("{}\n{}".format(e, traceback.format_exc())) # noqa
|
||||||
|
|
122
litellm/tests/test_max_tpm_rpm_limiter.py
Normal file
122
litellm/tests/test_max_tpm_rpm_limiter.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
# What is this?
|
||||||
|
## Unit tests for the max tpm / rpm limiter hook for proxy
|
||||||
|
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm import Router
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache, RedisCache
|
||||||
|
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_rpm_limits():
|
||||||
|
"""
|
||||||
|
Test if error raised on hitting rpm limits
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = True
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
|
||||||
|
local_cache = DualCache()
|
||||||
|
# redis_usage_cache = RedisCache()
|
||||||
|
|
||||||
|
local_cache.set_cache(
|
||||||
|
key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
|
||||||
|
)
|
||||||
|
|
||||||
|
tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None)
|
||||||
|
|
||||||
|
await tpm_rpm_limiter.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
|
||||||
|
|
||||||
|
await tpm_rpm_limiter.async_log_success_event(
|
||||||
|
kwargs=kwargs,
|
||||||
|
response_obj="",
|
||||||
|
start_time="",
|
||||||
|
end_time="",
|
||||||
|
)
|
||||||
|
|
||||||
|
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await tpm_rpm_limiter.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
data={},
|
||||||
|
call_type="",
|
||||||
|
)
|
||||||
|
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_team_rpm_limits():
|
||||||
|
"""
|
||||||
|
Test if error raised on hitting team rpm limits
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = True
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
_team_id = "unique-team-id"
|
||||||
|
_user_api_key_dict = {
|
||||||
|
"api_key": _api_key,
|
||||||
|
"max_parallel_requests": 1,
|
||||||
|
"tpm_limit": 9,
|
||||||
|
"rpm_limit": 10,
|
||||||
|
"team_rpm_limit": 1,
|
||||||
|
"team_id": _team_id,
|
||||||
|
}
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)
|
||||||
|
local_cache = DualCache()
|
||||||
|
local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
|
||||||
|
tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None)
|
||||||
|
|
||||||
|
await tpm_rpm_limiter.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"litellm_params": {
|
||||||
|
"metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await tpm_rpm_limiter.async_log_success_event(
|
||||||
|
kwargs=kwargs,
|
||||||
|
response_obj="",
|
||||||
|
start_time="",
|
||||||
|
end_time="",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"local_cache: {local_cache}")
|
||||||
|
|
||||||
|
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await tpm_rpm_limiter.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
data={},
|
||||||
|
call_type="",
|
||||||
|
)
|
||||||
|
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
assert e.status_code == 429
|
Loading…
Add table
Add a link
Reference in a new issue