diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 2b2054756..75acd87ce 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -68,6 +68,7 @@ router_settings:
 litellm_settings:
   success_callback: ["langfuse"]
   failure_callback: ["langfuse"]
+  cache: true
 
 # general_settings:
 #   alerting: ["email"]
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 4ba7a2229..a17fcb2c9 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -10,18 +10,17 @@ from datetime import datetime
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
-    user_api_key_cache = None
 
     # Class variables or attributes
-    def __init__(self):
-        pass
+    def __init__(self, internal_usage_cache: DualCache):
+        self.internal_usage_cache = internal_usage_cache
 
     def print_verbose(self, print_statement):
         try:
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement)  # noqa
-        except:
+        except Exception:
             pass
 
     async def check_key_in_limits(
@@ -35,7 +34,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         rpm_limit: int,
         request_count_api_key: str,
     ):
-        current = cache.get_cache(
+        current = await self.internal_usage_cache.async_get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         if current is None:
@@ -49,7 +48,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": 0,
                 "current_rpm": 0,
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         elif (
             int(current["current_requests"]) < max_parallel_requests
             and current["current_tpm"] < tpm_limit
@@ -61,7 +62,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": current["current_tpm"],
                 "current_rpm": current["current_rpm"],
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         else:
             raise HTTPException(
                 status_code=429,
@@ -75,7 +78,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         data: dict,
         call_type: str,
     ):
-        self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
+        self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests
         if max_parallel_requests is None:
@@ -90,7 +93,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if rpm_limit is None:
             rpm_limit = sys.maxsize
 
-        self.user_api_key_cache = cache  # save the api key cache for updating the value
         # ------------
         # Setup values
         # ------------
@@ -98,7 +100,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if global_max_parallel_requests is not None:
             # get value from cache
             _key = "global_max_parallel_requests"
-            current_global_requests = await cache.async_get_cache(
+            current_global_requests = await self.internal_usage_cache.async_get_cache(
                 key=_key, local_only=True
             )
             # check if below limit
@@ -111,7 +113,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 )
             # if below -> increment
             else:
-                await cache.async_increment_cache(key=_key, value=1, local_only=True)
+                await self.internal_usage_cache.async_increment_cache(
+                    key=_key, value=1, local_only=True
+                )
 
         current_date = datetime.now().strftime("%Y-%m-%d")
         current_hour = datetime.now().strftime("%H")
@@ -123,7 +127,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         request_count_api_key = f"{api_key}::{precise_minute}::request_count"
 
         # CHECK IF REQUEST ALLOWED for key
-        current = cache.get_cache(
+        current = await self.internal_usage_cache.async_get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
         self.print_verbose(f"current: {current}")
@@ -143,7 +147,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": 0,
                 "current_rpm": 0,
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         elif (
             int(current["current_requests"]) < max_parallel_requests
             and current["current_tpm"] < tpm_limit
@@ -155,7 +161,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "current_tpm": current["current_tpm"],
                 "current_rpm": current["current_rpm"],
             }
-            cache.set_cache(request_count_api_key, new_val)
+            await self.internal_usage_cache.async_set_cache(
+                request_count_api_key, new_val
+            )
         else:
             raise HTTPException(
                 status_code=429, detail="Max parallel request limit reached."
@@ -164,7 +172,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
         if user_id is not None:
-            _user_id_rate_limits = await self.user_api_key_cache.async_get_cache(
+            _user_id_rate_limits = await self.internal_usage_cache.async_get_cache(
                 key=user_id
             )
             # get user tpm/rpm limits
@@ -256,7 +264,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
+            self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
                 "global_max_parallel_requests", None
            )
@@ -269,9 +277,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             )
             user_api_key_end_user_id = kwargs.get("user")
 
-            if self.user_api_key_cache is None:
-                return
-
             # ------------
             # Setup values
             # ------------
@@ -280,7 +285,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             if global_max_parallel_requests is not None:
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )
@@ -303,7 +308,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -320,7 +325,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
@@ -337,7 +342,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_user_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -354,7 +359,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
@@ -371,7 +376,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_team_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -388,7 +393,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
@@ -405,7 +410,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     f"{user_api_key_end_user_id}::{precise_minute}::request_count"
                 )
 
-                current = self.user_api_key_cache.get_cache(
+                current = await self.internal_usage_cache.async_get_cache(
                     key=request_count_api_key
                 ) or {
                     "current_requests": 1,
@@ -422,7 +427,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 self.print_verbose(
                     f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
                 )
-                self.user_api_key_cache.set_cache(
+                await self.internal_usage_cache.async_set_cache(
                     request_count_api_key, new_val, ttl=60
                 )  # store in cache for 1 min.
@@ -442,9 +447,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             if user_api_key is None:
                 return
 
-            if self.user_api_key_cache is None:
-                return
-
             ## decrement call count if call failed
             if "Max parallel request limit reached" in str(kwargs["exception"]):
                 pass  # ignore failed calls due to max limit being reached
@@ -457,12 +459,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 # get value from cache
                 _key = "global_max_parallel_requests"
                 current_global_requests = (
-                    await self.user_api_key_cache.async_get_cache(
+                    await self.internal_usage_cache.async_get_cache(
                         key=_key, local_only=True
                     )
                 )
                 # decrement
-                await self.user_api_key_cache.async_increment_cache(
+                await self.internal_usage_cache.async_increment_cache(
                     key=_key, value=-1, local_only=True
                 )
@@ -478,7 +480,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             # ------------
             # Update usage
             # ------------
-            current = self.user_api_key_cache.get_cache(
+            current = await self.internal_usage_cache.async_get_cache(
                 key=request_count_api_key
             ) or {
                 "current_requests": 1,
@@ -493,7 +495,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             }
 
             self.print_verbose(f"updated_value in failure call: {new_val}")
-            self.user_api_key_cache.set_cache(
+            await self.internal_usage_cache.async_set_cache(
                 request_count_api_key, new_val, ttl=60
             )  # save in cache for up to 1 min.
         except Exception as e:
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d1591d188..8fadfc728 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2852,6 +2852,7 @@ class ProxyConfig:
             use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
             load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
             ### ALERTING ###
+
             proxy_logging_obj.update_values(
                 alerting=general_settings.get("alerting", None),
                 alerting_threshold=general_settings.get("alerting_threshold", 600),
@@ -3963,6 +3964,11 @@ async def startup_event():
 
         db_writer_client = HTTPHandler()
 
+    ## UPDATE INTERNAL USAGE CACHE ##
+    proxy_logging_obj.update_values(
+        redis_cache=redis_usage_cache
+    )  # used by parallel request limiter for rate limiting keys across instances
+
     proxy_logging_obj._init_litellm_callbacks()  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
     if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 54782c088..4ac333bf4 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -162,8 +162,12 @@ class ProxyLogging:
         ## INITIALIZE LITELLM CALLBACKS ##
         self.call_details: dict = {}
         self.call_details["user_api_key_cache"] = user_api_key_cache
-        self.internal_usage_cache = DualCache()
-        self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler()
+        self.internal_usage_cache = DualCache(
+            default_in_memory_ttl=1
+        )  # ping redis cache every 1s
+        self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler(
+            self.internal_usage_cache
+        )
         self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
@@ -189,39 +193,45 @@ class ProxyLogging:
 
     def update_values(
         self,
-        alerting: Optional[List],
-        alerting_threshold: Optional[float],
-        redis_cache: Optional[RedisCache],
+        alerting: Optional[List] = None,
+        alerting_threshold: Optional[float] = None,
+        redis_cache: Optional[RedisCache] = None,
         alert_types: Optional[List[AlertType]] = None,
         alerting_args: Optional[dict] = None,
     ):
-        self.alerting = alerting
+        updated_slack_alerting: bool = False
+        if alerting is not None:
+            self.alerting = alerting
+            updated_slack_alerting = True
         if alerting_threshold is not None:
             self.alerting_threshold = alerting_threshold
+            updated_slack_alerting = True
         if alert_types is not None:
             self.alert_types = alert_types
+            updated_slack_alerting = True
 
-        self.slack_alerting_instance.update_values(
-            alerting=self.alerting,
-            alerting_threshold=self.alerting_threshold,
-            alert_types=self.alert_types,
-            alerting_args=alerting_args,
-        )
+        if updated_slack_alerting is True:
+            self.slack_alerting_instance.update_values(
+                alerting=self.alerting,
+                alerting_threshold=self.alerting_threshold,
+                alert_types=self.alert_types,
+                alerting_args=alerting_args,
+            )
 
-        if (
-            self.alerting is not None
-            and "slack" in self.alerting
-            and "daily_reports" in self.alert_types
-        ):
-            # NOTE: ENSURE we only add callbacks when alerting is on
-            # We should NOT add callbacks when alerting is off
-            litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+            if (
+                self.alerting is not None
+                and "slack" in self.alerting
+                and "daily_reports" in self.alert_types
+            ):
+                # NOTE: ENSURE we only add callbacks when alerting is on
+                # We should NOT add callbacks when alerting is off
+                litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
 
         if redis_cache is not None:
             self.internal_usage_cache.redis_cache = redis_cache
 
     def _init_litellm_callbacks(self):
-        print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
+        print_verbose("INITIALIZING LITELLM CALLBACKS!")
         self.service_logging_obj = ServiceLogging()
         litellm.callbacks.append(self.max_parallel_request_limiter)
         litellm.callbacks.append(self.max_budget_limiter)
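
For reference, a minimal sketch (not part of the patch) of how the refactored pieces fit together: the limiter now receives the shared internal_usage_cache at construction time instead of stashing whichever cache the pre-call hook happened to be passed, and attaching a RedisCache at startup is what lets rate-limit counters sync across proxy instances. DualCache, RedisCache, and _PROXY_MaxParallelRequestsHandler are litellm's own classes, used as the diff above uses them; the Redis host/port below are placeholder values.

    from litellm.caching import DualCache, RedisCache
    from litellm.proxy.hooks.parallel_request_limiter import (
        _PROXY_MaxParallelRequestsHandler,
    )

    # Shared cache: in-memory entries expire after 1s, so reads fall through
    # to Redis and pick up counters written by other proxy instances.
    internal_usage_cache = DualCache(default_in_memory_ttl=1)

    # The limiter is constructed with the shared cache (the new __init__ above),
    # so pre-call, success, and failure hooks all read and write the same counters.
    limiter = _PROXY_MaxParallelRequestsHandler(internal_usage_cache)

    # At startup, attach the Redis layer; this mirrors
    # proxy_logging_obj.update_values(redis_cache=redis_usage_cache).
    internal_usage_cache.redis_cache = RedisCache(host="localhost", port=6379)  # placeholder host/port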