# litellm-mirror/litellm/tests/test_max_tpm_rpm_limiter.py
# What is this?
## Unit tests for the max tpm / rpm limiter hook for proxy
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
from typing import Optional
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache, RedisCache
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
from datetime import datetime
@pytest.mark.asyncio
async def test_pre_call_hook_rpm_limits():
    """
    Verify the limiter raises HTTP 429 once a key's rpm_limit is exhausted.

    Flow:
      1. pre-call hook admits the first request (usage is 0),
      2. a success event is logged so the limiter counts 1 rpm for the key,
      3. the next pre-call hook must fail with status_code 429
         (rpm_limit for the key is 1).
    """
    litellm.set_verbose = True
    _api_key = hash_token("sk-12345")
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, tpm_limit=9, rpm_limit=1)
    local_cache = DualCache()
    # redis_usage_cache = RedisCache()
    local_cache.set_cache(
        key=_api_key, value={"api_key": _api_key, "tpm_limit": 9, "rpm_limit": 1}
    )
    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=DualCache())

    # First request: should pass, no usage recorded yet.
    await tpm_rpm_limiter.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    # Log a successful call so the limiter's usage counters are incremented.
    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
    await tpm_rpm_limiter.async_log_success_event(
        kwargs=kwargs,
        response_obj="",
        start_time="",
        end_time="",
    )

    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
    try:
        await tpm_rpm_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )
        pytest.fail("Expected call to fail")
    except Exception as e:
        # getattr keeps an unrelated exception from surfacing as an
        # AttributeError — it becomes a clear assertion failure instead.
        assert getattr(e, "status_code", None) == 429, f"Unexpected exception: {e}"
@pytest.mark.asyncio
async def test_pre_call_hook_team_rpm_limits(
    _redis_usage_cache: Optional[RedisCache] = None,
):
    """
    Verify the limiter raises HTTP 429 once a TEAM's rpm limit is exhausted,
    even though the individual key's own rpm_limit (10) is not hit.

    Args:
        _redis_usage_cache: optional Redis backend for the limiter's internal
            cache; also invoked by test_namespace with a namespaced cache.
    """
    litellm.set_verbose = True
    _api_key = "sk-12345"
    _team_id = "unique-team-id"
    _user_api_key_dict = {
        "api_key": _api_key,
        "max_parallel_requests": 1,
        "tpm_limit": 9,
        "rpm_limit": 10,
        "team_rpm_limit": 1,  # the binding limit for this test
        "team_id": _team_id,
    }
    user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict)  # type: ignore
    # The limiter looks keys up by their hashed value.
    _api_key = hash_token(_api_key)
    local_cache = DualCache()
    local_cache.set_cache(key=_api_key, value=_user_api_key_dict)
    internal_cache = DualCache(redis_cache=_redis_usage_cache)
    tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(internal_cache=internal_cache)

    # First request: should pass, no team usage recorded yet.
    await tpm_rpm_limiter.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    # Log a successful call attributed to both the key and the team.
    kwargs = {
        "litellm_params": {
            "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
        }
    }
    await tpm_rpm_limiter.async_log_success_event(
        kwargs=kwargs,
        response_obj="",
        start_time="",
        end_time="",
    )
    print(f"local_cache: {local_cache}")

    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
    try:
        await tpm_rpm_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )
        pytest.fail("Expected call to fail")
    except Exception as e:
        # getattr keeps an unrelated exception from surfacing as an
        # AttributeError — it becomes a clear assertion failure instead.
        assert getattr(e, "status_code", None) == 429, f"Unexpected exception: {e}"
@pytest.mark.asyncio
async def test_namespace():
    """
    - test if default namespace set via `ProxyConfig._init_cache`
    - respected for tpm/rpm caching (usage keys land under the namespace)
    """
    from litellm.proxy.proxy_server import ProxyConfig

    cache_params = {"type": "redis", "namespace": "litellm_default"}

    ## INIT CACHE ##
    proxy_config = ProxyConfig()
    setattr(litellm.proxy.proxy_server, "proxy_config", proxy_config)
    proxy_config._init_cache(cache_params=cache_params)
    redis_cache: Optional[RedisCache] = getattr(
        litellm.proxy.proxy_server, "redis_usage_cache"
    )

    ## CHECK IF NAMESPACE SET ##
    # Explicit None-check so a missing cache fails the test cleanly
    # instead of raising AttributeError below.
    assert redis_cache is not None, "redis_usage_cache was not initialized"
    assert redis_cache.namespace == "litellm_default"

    ## CHECK IF TPM/RPM RATE LIMITING WORKS ##
    await test_pre_call_hook_team_rpm_limits(_redis_usage_cache=redis_cache)

    # Usage entries are written under "<namespace>:usage:<YYYY-MM-DD-HH-MM>".
    # Single datetime.now() call avoids a flaky key if the clock rolls over
    # a minute boundary between separate strftime calls.
    precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
    cache_key = "litellm_default:usage:{}".format(precise_minute)
    value = await redis_cache.async_get_cache(key=cache_key)
    assert value is not None