diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py
index 0c6996b10..dc6f35642 100644
--- a/litellm/_service_logger.py
+++ b/litellm/_service_logger.py
@@ -21,7 +21,9 @@ class ServiceLogging(CustomLogger):
         if "prometheus_system" in litellm.service_callback:
             self.prometheusServicesLogger = PrometheusServicesLogger()
 
-    def service_success_hook(self, service: ServiceTypes, duration: float):
+    def service_success_hook(
+        self, service: ServiceTypes, duration: float, call_type: str
+    ):
         """
         [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
         """
@@ -29,7 +31,7 @@ class ServiceLogging(CustomLogger):
             self.mock_testing_sync_success_hook += 1
 
     def service_failure_hook(
-        self, service: ServiceTypes, duration: float, error: Exception
+        self, service: ServiceTypes, duration: float, error: Exception, call_type: str
     ):
         """
         [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
diff --git a/litellm/caching.py b/litellm/caching.py
index c8ebd17df..f6d826de5 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -217,6 +217,7 @@ class RedisCache(BaseCache):
                 self.service_logger_obj.service_success_hook(
                     service=ServiceTypes.REDIS,
                     duration=_duration,
+                    call_type="increment_cache",
                 )
             )
             return result
@@ -226,11 +227,14 @@ class RedisCache(BaseCache):
             _duration = end_time - start_time
             asyncio.create_task(
                 self.service_logger_obj.async_service_failure_hook(
-                    service=ServiceTypes.REDIS, duration=_duration, error=e
+                    service=ServiceTypes.REDIS,
+                    duration=_duration,
+                    error=e,
+                    call_type="increment_cache",
                 )
             )
             verbose_logger.error(
-                "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
+                "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
                 str(e),
                 value,
             )
@@ -278,6 +282,9 @@ class RedisCache(BaseCache):
 
     async def async_set_cache(self, key, value, **kwargs):
         start_time = time.time()
+        print_verbose(
+            f"Set Async Redis Cache: key: {key}\nValue {value}\nttl={kwargs.get('ttl', None)}, redis_version={self.redis_version}"
+        )
         try:
             _redis_client = self.init_async_client()
         except Exception as e:
@@ -341,6 +348,10 @@ class RedisCache(BaseCache):
         """
         _redis_client = self.init_async_client()
         start_time = time.time()
+
+        print_verbose(
+            f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
+        )
         try:
             async with _redis_client as redis_client:
                 async with redis_client.pipeline(transaction=True) as pipe:
@@ -1261,7 +1272,6 @@ class DualCache(BaseCache):
                 print_verbose(f"in_memory_result: {in_memory_result}")
                 if in_memory_result is not None:
                     result = in_memory_result
-
             if None in result and self.redis_cache is not None and local_only == False:
                 """
                 - for the none values in the result
@@ -1277,14 +1287,12 @@ class DualCache(BaseCache):
                 if redis_result is not None:
                     # Update in-memory cache with the value from Redis
-                    for key in redis_result:
-                        await self.in_memory_cache.async_set_cache(
-                            key, redis_result[key], **kwargs
-                        )
-
-                    sublist_dict = dict(zip(sublist_keys, redis_result))
-
-                    for key, value in sublist_dict.items():
+                    for key, value in redis_result.items():
+                        if value is not None:
+                            await self.in_memory_cache.async_set_cache(
+                                key, redis_result[key], **kwargs
+                            )
+                    for key, value in redis_result.items():
                         result[sublist_keys.index(key)] = value
 
             print_verbose(f"async batch get cache: cache result: {result}")
@@ -1293,6 +1301,9 @@ class DualCache(BaseCache):
             traceback.print_exc()
 
     async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
+        print_verbose(
+            f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
+        )
         try:
             if self.in_memory_cache is not None:
                 await self.in_memory_cache.async_set_cache(key, value, **kwargs)
diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 503b3ff9d..b288036ad 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -6,7 +6,7 @@ import requests
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-from typing import Literal, Union
+from typing import Literal, Union, Optional
 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -46,6 +46,17 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         pass
 
+    #### PRE-CALL CHECKS - router/proxy only ####
+    """
+    Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
+    """
+
+    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
+        pass
+
+    def pre_call_check(self, deployment: dict) -> Optional[dict]:
+        pass
+
     #### CALL HOOKS - proxy only ####
     """
     Control the modify incoming / outgoung data before calling the model
diff --git a/litellm/router.py b/litellm/router.py
index abbd6343b..fcb5424f6 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -31,6 +31,7 @@ import copy
 from litellm._logging import verbose_router_logger
 import logging
 from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
+from litellm.integrations.custom_logger import CustomLogger
 
 
 class Router:
@@ -492,18 +493,18 @@ class Router:
                 deployment=deployment, kwargs=kwargs, client_type="rpm_client"
             )
 
-            if (
-                rpm_semaphore is not None
-                and isinstance(rpm_semaphore, asyncio.Semaphore)
-                and self.routing_strategy == "usage-based-routing-v2"
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
             ):
                 async with rpm_semaphore:
                     """
                     - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
-                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
+                    await self.routing_strategy_pre_call_checks(deployment=deployment)
                     response = await _response
             else:
+                await self.routing_strategy_pre_call_checks(deployment=deployment)
                 response = await _response
 
             self.success_calls[model_name] += 1
@@ -1712,6 +1713,22 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
+    async def routing_strategy_pre_call_checks(self, deployment: dict):
+        """
+        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+        -> makes the calls concurrency-safe, when rpm limits are set for a deployment
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over its tpm/rpm limits
+        """
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                response = await _callback.async_pre_call_check(deployment)
+
     def set_client(self, model: dict):
         """
         - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
@@ -2700,6 +2717,7 @@ class Router:
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
         )
+        return deployment
 
     def get_available_deployment(
diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index 6d7cc03ef..022bf5ffe 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -39,7 +39,81 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list
 
-    async def pre_call_rpm_check(self, deployment: dict) -> dict:
+    def pre_call_check(self, deployment: Dict) -> Dict | None:
+        """
+        Pre-call check + update model rpm
+
+        Returns - deployment
+
+        Raises - RateLimitError if deployment over defined RPM limit
+        """
+        try:
+
+            # ------------
+            # Setup values
+            # ------------
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+            model_id = deployment.get("model_info", {}).get("id")
+            rpm_key = f"{model_id}:rpm:{current_minute}"
+            local_result = self.router_cache.get_cache(
+                key=rpm_key, local_only=True
+            )  # check local result first
+
+            deployment_rpm = None
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("model_info", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = float("inf")
+
+            if local_result is not None and local_result >= deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, local_result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            local_result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    ),
+                )
+            else:
+                # if local result below limit, check redis ## prevent unnecessary redis checks
+                result = self.router_cache.increment_cache(key=rpm_key, value=1)
+                if result is not None and result > deployment_rpm:
+                    raise litellm.RateLimitError(
+                        message="Deployment over defined rpm limit={}. current usage={}".format(
+                            deployment_rpm, result
+                        ),
+                        llm_provider="",
+                        model=deployment.get("litellm_params", {}).get("model"),
+                        response=httpx.Response(
+                            status_code=429,
+                            content="{} rpm limit={}. current usage={}".format(
+                                RouterErrors.user_defined_ratelimit_error.value,
+                                deployment_rpm,
+                                result,
+                            ),
+                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                        ),
+                    )
+            return deployment
+        except Exception as e:
+            if isinstance(e, litellm.RateLimitError):
+                raise e
+            return deployment  # don't fail calls if eg. redis fails to connect
+
+    async def async_pre_call_check(self, deployment: Dict) -> Dict | None:
         """
         Pre-call check + update model rpm
         - Used inside semaphore
@@ -58,8 +132,8 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             # ------------
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
-            model_group = deployment.get("model_name", "")
-            rpm_key = f"{model_group}:rpm:{current_minute}"
+            model_id = deployment.get("model_info", {}).get("id")
+            rpm_key = f"{model_id}:rpm:{current_minute}"
             local_result = await self.router_cache.async_get_cache(
                 key=rpm_key, local_only=True
             )  # check local result first
@@ -246,21 +320,26 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             for deployment in healthy_deployments:
                 tpm_dict[deployment["model_info"]["id"]] = 0
         else:
+            dt = get_utc_datetime()
+            current_minute = dt.strftime(
+                "%H-%M"
+            )  # use the same timezone regardless of system clock
+
             for d in healthy_deployments:
                 ## if healthy deployment not yet used
-                if d["model_info"]["id"] not in tpm_dict:
-                    tpm_dict[d["model_info"]["id"]] = 0
+                tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}"
+                if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
+                    tpm_dict[tpm_key] = 0
 
         all_deployments = tpm_dict
-
         deployment = None
         for item, item_tpm in all_deployments.items():
             ## get the item from model list
             _deployment = None
+            item = item.split(":")[0]
             for m in healthy_deployments:
                 if item == m["model_info"]["id"]:
                     _deployment = m
-
             if _deployment is None:
                 continue  # skip to next one
@@ -283,7 +362,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
             if _deployment_rpm is None:
                 _deployment_rpm = float("inf")
-
             if item_tpm + input_tokens > _deployment_tpm:
                 continue
             elif (rpm_dict is not None and item in rpm_dict) and (
diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py
index bb6a9e45b..84728a78c 100644
--- a/litellm/tests/test_tpm_rpm_routing_v2.py
+++ b/litellm/tests/test_tpm_rpm_routing_v2.py
@@ -1,5 +1,5 @@
 #### What this tests ####
-# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2'
+# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2'
 
 import sys, os, asyncio, time, random
 from datetime import datetime
@@ -18,6 +18,7 @@ import litellm
 from litellm.router_strategy.lowest_tpm_rpm_v2 import (
     LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler,
 )
+from litellm.utils import get_utc_datetime
 from litellm.caching import DualCache
 
 ### UNIT TESTS FOR TPM/RPM ROUTING ###
@@ -43,20 +44,23 @@ def test_tpm_rpm_updated():
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 50}}
     end_time = time.time()
+    lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
         end_time=end_time,
     )
-    current_minute = datetime.now().strftime("%H-%M")
-    tpm_count_api_key = f"{model_group}:tpm:{current_minute}"
-    rpm_count_api_key = f"{model_group}:rpm:{current_minute}"
-    assert (
-        response_obj["usage"]["total_tokens"]
-        == test_cache.get_cache(key=tpm_count_api_key)[deployment_id]
+    dt = get_utc_datetime()
+    current_minute = dt.strftime("%H-%M")
+    tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
+    rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
+
+    print(f"tpm_count_api_key={tpm_count_api_key}")
+    assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
+        key=tpm_count_api_key
     )
-    assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id]
+    assert 1 == test_cache.get_cache(key=rpm_count_api_key)
 
 
 # test_tpm_rpm_updated()
@@ -122,13 +126,6 @@ def test_get_available_deployments():
     )
 
     ## CHECK WHAT'S SELECTED ##
-    print(
-        lowest_tpm_logger.get_available_deployments(
-            model_group=model_group,
-            healthy_deployments=model_list,
-            input=["Hello world"],
-        )
-    )
     assert (
         lowest_tpm_logger.get_available_deployments(
             model_group=model_group,
@@ -170,7 +167,7 @@
     ]
     router = Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         set_verbose=False,
         num_retries=3,
     )  # type: ignore
@@ -189,7 +186,7 @@
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 50}}
     end_time = time.time()
-    router.lowesttpm_logger.log_success_event(
+    router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
@@ -208,7 +205,7 @@
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 20}}
     end_time = time.time()
-    router.lowesttpm_logger.log_success_event(
+    router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
@@ -216,7 +213,7 @@
     )
 
     ## CHECK WHAT'S SELECTED ##
-    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
+    # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
     assert (
         router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2"
     )
@@ -244,7 +241,7 @@
     ]
     router = Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         set_verbose=False,
         num_retries=3,
     )  # type: ignore
@@ -262,7 +259,7 @@
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": 1439}}
     end_time = time.time()
-    router.lowesttpm_logger.log_success_event(
+    router.lowesttpm_logger_v2.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
@@ -270,7 +267,7 @@
     )
 
     ## CHECK WHAT'S SELECTED ##
-    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
+    # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model"))
     try:
         router.get_available_deployment(
             model="azure-model",
@@ -299,7 +296,7 @@ def test_single_deployment_tpm_zero():
 
     router = litellm.Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         cache_responses=True,
     )
@@ -345,7 +342,7 @@
     ]
     router = Router(
         model_list=model_list,
-        routing_strategy="usage-based-routing",
+        routing_strategy="usage-based-routing-v2",
         set_verbose=False,
     )  # type: ignore
@@ -362,8 +359,9 @@
     if response is not None:
         ## CALL 3
         await asyncio.sleep(1)  # let the token update happen
-        current_minute = datetime.now().strftime("%H-%M")
-        picked_deployment = router.lowesttpm_logger.get_available_deployments(
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        picked_deployment = router.lowesttpm_logger_v2.get_available_deployments(
             model_group=model,
             healthy_deployments=router.healthy_deployments,
             messages=messages,