# What is this?
## Unit tests for the max tpm / rpm limiter hook for the proxy

import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from typing import Optional

load_dotenv()

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache, RedisCache
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter


@pytest.mark.asyncio
async def test_pre_call_hook_rpm_limits():
    """
    Test that an error is raised on hitting the key's rpm limit.
    """
    litellm.set_verbose = True
    _api_key = hash_token("sk-12345")
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
    local_cache = DualCache()
    # redis_usage_cache = RedisCache()

    local_cache.set_cache(
        key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
    )

    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())

    ## 1st request - should pass, rpm limit (=1) not yet hit
    await tpm_rpm_limiter.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}

    ## log a successful call - increments the key's current_rpm
    await tpm_rpm_limiter.async_log_success_event(
        kwargs=kwargs,
        response_obj="",
        start_time="",
        end_time="",
    )

    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}

    ## 2nd request - should fail, rpm limit now hit
    try:
        await tpm_rpm_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )

        pytest.fail("Expected call to fail")
    except Exception as e:
        assert e.status_code == 429  # type: ignore


@pytest.mark.asyncio
async def test_pre_call_hook_team_rpm_limits(
    _redis_usage_cache: Optional[RedisCache] = None,
):
    """
    Test that an error is raised on hitting the team's rpm limit.
    """
    litellm.set_verbose = True
    _api_key = "sk-12345"
    _team_id = "unique-team-id"
    _user_api_key_dict = {
        "api_key": _api_key,
        "max_parallel_requests": 1,
        "tpm_limit": 9,
        "rpm_limit": 10,
        "team_rpm_limit": 1,
        "team_id": _team_id,
    }
    user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
    _api_key = hash_token(_api_key)
    local_cache = DualCache()
    local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
    internal_cache = DualCache(redis_cache=_redis_usage_cache)
    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=internal_cache)

    ## 1st request - should pass, team rpm limit (=1) not yet hit
    await tpm_rpm_limiter.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    kwargs = {
        "litellm_params": {
            "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
        }
    }

    ## log a successful call - increments the team's current_rpm
    await tpm_rpm_limiter.async_log_success_event(
        kwargs=kwargs,
        response_obj="",
        start_time="",
        end_time="",
    )

    print(f"local_cache: {local_cache}")

    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}

    ## 2nd request - should fail, team rpm limit now hit
    try:
        await tpm_rpm_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )

        pytest.fail("Expected call to fail")
    except Exception as e:
        assert e.status_code == 429  # type: ignore


@pytest.mark.asyncio
async def test_namespace():
    """
    - test if the default namespace set via `ProxyConfig._init_cache`
    - is respected for tpm/rpm caching
    """
    from litellm.proxy.proxy_server import ProxyConfig

    cache_params = {"type": "redis", "namespace": "litellm_default"}

    ## INIT CACHE ##
    proxy_config = ProxyConfig()
    setattr(litellm.proxy.proxy_server, "proxy_config", proxy_config)

    proxy_config._init_cache(cache_params=cache_params)

    redis_cache: Optional[RedisCache] = getattr(
        litellm.proxy.proxy_server, "redis_usage_cache"
    )

    ## CHECK IF NAMESPACE SET ##
    assert redis_cache.namespace == "litellm_default"

    ## CHECK IF TPM/RPM RATE LIMITING WORKS ##
    await test_pre_call_hook_team_rpm_limits(_redis_usage_cache=redis_cache)

    ## usage is tracked per-minute under the configured namespace
    current_date = datetime.now().strftime("%Y-%m-%d")
    current_hour = datetime.now().strftime("%H")
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"

    cache_key = "litellm_default:usage:{}".format(precise_minute)
    value = await redis_cache.async_get_cache(key=cache_key)
    assert value is not None