diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f819791b7..383e4c2c7 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1908,6 +1908,19 @@ class ProxyConfig: team_config[k] = litellm.get_secret(v) return team_config + def _init_cache( + self, + cache_params: dict, + ): + global redis_usage_cache + from litellm import Cache + + litellm.cache = Cache(**cache_params) + + if litellm.cache is not None and isinstance(litellm.cache.cache, RedisCache): + ## INIT PROXY REDIS USAGE CLIENT ## + redis_usage_cache = litellm.cache.cache + async def load_config( self, router: Optional[litellm.Router], config_file_path: str ): @@ -2001,17 +2014,11 @@ class ProxyConfig: cache_params[key] = litellm.get_secret(value) ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables - - litellm.cache = Cache(**cache_params) - - if litellm.cache is not None and isinstance( - litellm.cache.cache, RedisCache - ): - ## INIT PROXY REDIS USAGE CLIENT ## - redis_usage_cache = litellm.cache.cache - print( # noqa - f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}" - ) + self._init_cache(cache_params=cache_params) + if litellm.cache is not None: + print( # noqa + f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}" + ) elif key == "cache" and value == False: pass elif key == "callbacks": diff --git a/litellm/tests/test_max_tpm_rpm_limiter.py b/litellm/tests/test_max_tpm_rpm_limiter.py index db1ab0f86..40a978c62 100644 --- a/litellm/tests/test_max_tpm_rpm_limiter.py +++ b/litellm/tests/test_max_tpm_rpm_limiter.py @@ -5,6 +5,7 @@ import sys, os, asyncio, time, random from datetime import datetime import traceback from dotenv import load_dotenv +from typing import Optional load_dotenv() import os @@ -68,7 +69,9 @@ async def test_pre_call_hook_rpm_limits(): @pytest.mark.asyncio -async def test_pre_call_hook_team_rpm_limits(): +async def test_pre_call_hook_team_rpm_limits( + _redis_usage_cache: Optional[RedisCache] = None, +): """ Test if error raised on hitting team rpm limits """ @@ -83,10 +86,10 @@ async def test_pre_call_hook_team_rpm_limits(): "team_rpm_limit": 1, "team_id": _team_id, } - user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict) + user_api_key_dict = UserAPIKeyAuth(**_user_api_key_dict) # type: ignore local_cache = DualCache() local_cache.set_cache(key=_api_key, value=_user_api_key_dict) - tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=None) + tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(redis_usage_cache=_redis_usage_cache) await tpm_rpm_limiter.async_pre_call_hook( user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" @@ -119,4 +122,40 @@ async def test_pre_call_hook_team_rpm_limits(): pytest.fail(f"Expected call to fail") except Exception as e: - assert e.status_code == 429 + assert e.status_code == 429 # type: ignore + + +@pytest.mark.asyncio +async def test_namespace(): + """ + - test if default namespace set via `proxyconfig._init_cache` + - respected for tpm/rpm caching + """ + from litellm.proxy.proxy_server import ProxyConfig + + redis_usage_cache: Optional[RedisCache] = None + cache_params = {"type": "redis", "namespace": "litellm_default"} + + ## INIT CACHE ## + proxy_config = ProxyConfig() + setattr(litellm.proxy.proxy_server, "proxy_config", proxy_config) + + proxy_config._init_cache(cache_params=cache_params) + + redis_cache: Optional[RedisCache] = getattr( + litellm.proxy.proxy_server, "redis_usage_cache" + ) + + ## CHECK IF NAMESPACE SET ## + assert redis_cache.namespace == "litellm_default" + + ## CHECK IF TPM/RPM RATE LIMITING WORKS ## + await test_pre_call_hook_team_rpm_limits(_redis_usage_cache=redis_cache) + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + cache_key = "litellm_default:usage:{}".format(precise_minute) + value = await redis_cache.async_get_cache(key=cache_key) + assert value is not None