From c03b0bbb24cfbfc5105401c729fe20f465f8440c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 12 Apr 2024 18:25:14 -0700
Subject: [PATCH] fix(router.py): support pre_call_rpm_check for lowest_tpm_rpm_v2 routing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

have routing strategies expose an ‘update rpm’ function; for checking + updating rpm pre call
---
 litellm/caching.py                           | 23 ++++++--
 litellm/router.py                            | 19 ++-----
 litellm/router_strategy/lowest_tpm_rpm_v2.py | 58 ++++++++++++++++++--
 3 files changed, 75 insertions(+), 25 deletions(-)

diff --git a/litellm/caching.py b/litellm/caching.py
index 2401d9708..cdb98d790 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -98,11 +98,12 @@ class InMemoryCache(BaseCache):
             return_val.append(val)
         return return_val
 
-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         # get the value
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+        return value
 
     def flush_cache(self):
         self.cache_dict.clear()
@@ -266,11 +267,12 @@ class RedisCache(BaseCache):
             if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
                 await self.flush_cache_buffer()
 
-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         _redis_client = self.init_async_client()
         try:
             async with _redis_client as redis_client:
-                await redis_client.incr(name=key, amount=value)
+                result = await redis_client.incr(name=key, amount=value)
+                return result
         except Exception as e:
             verbose_logger.error(
                 "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
@@ -278,6 +280,7 @@
                 value,
             )
             traceback.print_exc()
+            raise e
 
     async def flush_cache_buffer(self):
         print_verbose(
@@ -1076,21 +1079,29 @@ class DualCache(BaseCache):
     async def async_increment_cache(
         self, key, value: int, local_only: bool = False, **kwargs
-    ):
+    ) -> int:
         """
         Key - the key in cache
         Value - int - the value you want to increment by
+
+        Returns - int - the incremented value
         """
         try:
+            result: int = value
             if self.in_memory_cache is not None:
-                await self.in_memory_cache.async_increment(key, value, **kwargs)
+                result = await self.in_memory_cache.async_increment(
+                    key, value, **kwargs
+                )
 
             if self.redis_cache is not None and local_only == False:
-                await self.redis_cache.async_increment(key, value, **kwargs)
+                result = await self.redis_cache.async_increment(key, value, **kwargs)
+
+            return result
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()
+            raise e
 
     def flush_cache(self):
         if self.in_memory_cache is not None:
diff --git a/litellm/router.py b/litellm/router.py
index c127c3f8b..9abf70956 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -491,23 +491,16 @@ class Router:
                 deployment=deployment, kwargs=kwargs, client_type="rpm_client"
             )
 
-            if rpm_semaphore is not None and isinstance(
-                rpm_semaphore, asyncio.Semaphore
+            if (
+                rpm_semaphore is not None
+                and isinstance(rpm_semaphore, asyncio.Semaphore)
+                and self.routing_strategy == "usage-based-routing-v2"
             ):
                 async with rpm_semaphore:
                     """
-                    - Check against in-memory tpm/rpm limits before making the call
+                    - Check rpm limits before making the call
                     """
-                    dt = get_utc_datetime()
-                    current_minute = dt.strftime("%H-%M")
-                    id = kwargs["model_info"]["id"]
-                    rpm_key = "{}:rpm:{}".format(id, current_minute)
-                    curr_rpm = await self.cache.async_get_cache(key=rpm_key)
-                    if (
-                        curr_rpm is not None and curr_rpm >= data["rpm"]
-                    ):  # >= b/c the initial count is 0
-                        raise Exception("Rate Limit error")
-                    await self.cache.async_increment_cache(key=rpm_key, value=1)
+                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
                     response = await _response
             else:
                 response = await _response
diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index c5598c11e..305f564aa 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -7,7 +7,8 @@ import datetime as datetime_og
 from datetime import datetime
 
 dotenv.load_dotenv()  # Loading env variables using dotenv
-import traceback, asyncio
+import traceback, asyncio, httpx
+import litellm
 from litellm import token_counter
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
@@ -37,6 +38,55 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list
 
+    async def pre_call_rpm_check(self, deployment: dict) -> dict:
+        """
+        Pre-call check + update model rpm
+        - Used inside semaphore
+        - raise rate limit error if deployment over limit
+
+        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994
+
+        Returns - deployment
+
+        Raises - RateLimitError if deployment over defined RPM limit
+        """
+
+        # ------------
+        # Setup values
+        # ------------
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        model_group = deployment.get("model_name", "")
+        rpm_key = f"{model_group}:rpm:{current_minute}"
+        result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+
+        deployment_rpm = None
+        if deployment_rpm is None:
+            deployment_rpm = deployment.get("rpm")
+        if deployment_rpm is None:
+            deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+        if deployment_rpm is None:
+            deployment_rpm = deployment.get("model_info", {}).get("rpm")
+        if deployment_rpm is None:
+            deployment_rpm = float("inf")
+
+        if result is not None and result > deployment_rpm:
+            raise litellm.RateLimitError(
+                message="Deployment over defined rpm limit={}. current usage={}".format(
+                    deployment_rpm, result
+                ),
+                llm_provider="",
+                model=deployment.get("litellm_params", {}).get("model"),
+                response=httpx.Response(
+                    status_code=429,
+                    content="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, result
+                    ),
+                    request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                ),
+            )
+        return deployment
+
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
@@ -91,7 +141,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
-            Update TPM/RPM usage on success
+            Update TPM usage on success
             """
             if kwargs["litellm_params"].get("metadata") is None:
                 pass
@@ -117,8 +167,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             )  # use the same timezone regardless of system clock
 
             tpm_key = f"{id}:tpm:{current_minute}"
-            rpm_key = f"{id}:rpm:{current_minute}"
-
             # ------------
             # Update usage
             # ------------
@@ -128,8 +176,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             await self.router_cache.async_increment_cache(
                 key=tpm_key, value=total_tokens
             )
-            ## RPM
-            await self.router_cache.async_increment_cache(key=rpm_key, value=1)
 
             ### TESTING ###
             if self.test_flag: