From f3e47323b98de100922bfc19fe88e96de5e9d6e0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 1 Apr 2024 07:59:30 -0700 Subject: [PATCH] test(test_max_tpm_rpm_limiter.py): unit tests for key + team based tpm rpm limits on proxy --- litellm/proxy/hooks/tpm_rpm_limiter.py | 11 +- litellm/tests/test_max_tpm_rpm_limiter.py | 122 ++++++++++++++++++++++ 2 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 litellm/tests/test_max_tpm_rpm_limiter.py diff --git a/litellm/proxy/hooks/tpm_rpm_limiter.py b/litellm/proxy/hooks/tpm_rpm_limiter.py index db1d1759a..a46337491 100644 --- a/litellm/proxy/hooks/tpm_rpm_limiter.py +++ b/litellm/proxy/hooks/tpm_rpm_limiter.py @@ -102,6 +102,7 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): request_count_api_key: str, type: Literal["key", "user", "team"], ): + if type == "key" and user_api_key_dict.api_key is not None: current = current_minute_dict["key"].get(user_api_key_dict.api_key, None) elif type == "user" and user_api_key_dict.user_id is not None: @@ -110,7 +111,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): current = current_minute_dict["team"].get(user_api_key_dict.team_id, None) else: return - if current is None: if tpm_limit == 0 or rpm_limit == 0: # base case @@ -138,10 +138,14 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): ## get team tpm/rpm limits team_id = user_api_key_dict.team_id + self.user_api_key_cache = cache + _set_limits = self._check_limits_set( user_api_key_cache=cache, key=api_key, user_id=user_id, team_id=team_id ) + self.print_verbose(f"_set_limits: {_set_limits}") + if _set_limits == False: return @@ -149,8 +153,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): # Setup values # ------------ - self.user_api_key_cache = cache - current_date = datetime.now().strftime("%Y-%m-%d") current_hour = datetime.now().strftime("%H") current_minute = datetime.now().strftime("%M") @@ -247,7 +249,6 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_team_id", None ) - _limits_set = self._check_limits_set( user_api_key_cache=self.user_api_key_cache, key=user_api_key, @@ -377,4 +378,4 @@ class _PROXY_MaxTPMRPMLimiter(CustomLogger): ) except Exception as e: - self.print_verbose(e) # noqa + self.print_verbose("{}\n{}".format(e, traceback.format_exc())) # noqa diff --git a/litellm/tests/test_max_tpm_rpm_limiter.py b/litellm/tests/test_max_tpm_rpm_limiter.py new file mode 100644 index 000000000..db1ab0f86 --- /dev/null +++ b/litellm/tests/test_max_tpm_rpm_limiter.py @@ -0,0 +1,122 @@ +# What is this? +## Unit tests for the max tpm / rpm limiter hook for proxy + +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm import Router +from litellm.proxy.utils import ProxyLogging +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching import DualCache, RedisCache +from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter +from datetime import datetime + + +@pytest.mark.asyncio +async def test_pre_call_hook_rpm_limits(): + """ + Test if error raised on hitting rpm limits + """ + litellm.set_verbose = True + _api_key = "sk-12345" + user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1) + local_cache = DualCache() + # redis_usage_cache = RedisCache() + + local_cache.set_cache( + key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1} + ) + + tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None) + + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" + ) + + kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}} + + await tpm_rpm_limiter.async_log_success_event( + kwargs=kwargs, + response_obj="", + start_time="", + end_time="", + ) + + ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1} + + try: + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={}, + call_type="", + ) + + pytest.fail(f"Expected call to fail") + except Exception as e: + assert e.status_code == 429 + + +@pytest.mark.asyncio +async def test_pre_call_hook_team_rpm_limits(): + """ + Test if error raised on hitting team rpm limits + """ + litellm.set_verbose = True + _api_key = "sk-12345" + _team_id = "unique-team-id" + _user_api_key_dict = { + "api_key": _api_key, + "max_parallel_requests": 1, + "tpm_limit": 9, + "rpm_limit": 10, + "team_rpm_limit": 1, + "team_id": _team_id, + } + user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict) + local_cache = DualCache() + local_cache.set_cache(key=_api_key, value=_user_api_key_dict) + tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None) + + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" + ) + + kwargs = { + "litellm_params": { + "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id} + } + } + + await tpm_rpm_limiter.async_log_success_event( + kwargs=kwargs, + response_obj="", + start_time="", + end_time="", + ) + + print(f"local_cache: {local_cache}") + + ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1} + + try: + await tpm_rpm_limiter.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={}, + call_type="", + ) + + pytest.fail(f"Expected call to fail") + except Exception as e: + assert e.status_code == 429