fix(tpm_rpm_limiter.py): enable redis caching for tpm/rpm checks on keys/user/teams

allows tpm/rpm checks to work across instances

https://github.com/BerriAI/litellm/issues/2730
This commit is contained in:
Krrish Dholakia 2024-03-30 20:01:36 -07:00
parent 0342cd3b6b
commit f58fefd589
5 changed files with 423 additions and 11 deletions

View file

@ -12,13 +12,14 @@ from litellm.proxy._types import (
LiteLLM_TeamTable,
Member,
)
from litellm.caching import DualCache
from litellm.caching import DualCache, RedisCache
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
@ -46,16 +47,21 @@ class ProxyLogging:
- support the max parallel request integration
"""
def __init__(self, user_api_key_cache: DualCache):
def __init__(
self, user_api_key_cache: DualCache, redis_usage_cache: Optional[RedisCache]
):
## INITIALIZE LITELLM CALLBACKS ##
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
# self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
self.max_tpm_rpm_limiter = _PROXY_MaxTPMRPMLimiter(
redis_usage_cache=redis_usage_cache
)
self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
self.cache_control_check = _PROXY_CacheControlCheck()
self.alerting: Optional[List] = None
self.alerting_threshold: float = 300 # default to 5 min. threshold
pass
self.redis_usage_cache = redis_usage_cache
def update_values(
self, alerting: Optional[List], alerting_threshold: Optional[float]
@ -66,7 +72,8 @@ class ProxyLogging:
def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter)
# litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_tpm_rpm_limiter)
litellm.callbacks.append(self.max_budget_limiter)
litellm.callbacks.append(self.cache_control_check)
litellm.success_callback.append(self.response_taking_too_long_callback)