From a05f148c17dc1be46c9721767ed348642fd758fe Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 18 Apr 2024 20:01:07 -0700 Subject: [PATCH 1/4] fix(tpm_rpm_routing_v2.py): fix tpm rpm routing --- litellm/caching.py | 63 +++ litellm/router_strategy/lowest_tpm_rpm_v2.py | 20 +- ...outing.py => test_tpm_rpm_routing copy.py} | 0 litellm/tests/test_tpm_rpm_routing_v2.py | 387 ++++++++++++++++++ 4 files changed, 456 insertions(+), 14 deletions(-) rename litellm/tests/{test_tpm_rpm_routing.py => test_tpm_rpm_routing copy.py} (100%) create mode 100644 litellm/tests/test_tpm_rpm_routing_v2.py diff --git a/litellm/caching.py b/litellm/caching.py index bf1d61eec..c8ebd17df 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -89,6 +89,13 @@ class InMemoryCache(BaseCache): return_val.append(val) return return_val + def increment_cache(self, key, value: int, **kwargs) -> int: + # get the value + init_value = self.get_cache(key=key) or 0 + value = init_value + value + self.set_cache(key, value, **kwargs) + return value + async def async_get_cache(self, key, **kwargs): return self.get_cache(key=key, **kwargs) @@ -198,6 +205,38 @@ class RedisCache(BaseCache): f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}" ) + def increment_cache(self, key, value: int, **kwargs) -> int: + _redis_client = self.redis_client + start_time = time.time() + try: + result = _redis_client.incr(name=key, amount=value) + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + ) + ) + return result + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, duration=_duration, error=e + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + traceback.print_exc() + raise e + async def async_scan_iter(self, pattern: str, count: int = 100) -> list: start_time = time.time() try: @@ -1093,6 +1132,30 @@ class DualCache(BaseCache): except Exception as e: print_verbose(e) + def increment_cache( + self, key, value: int, local_only: bool = False, **kwargs + ) -> int: + """ + Key - the key in cache + + Value - int - the value you want to increment by + + Returns - int - the incremented value + """ + try: + result: int = value + if self.in_memory_cache is not None: + result = self.in_memory_cache.increment_cache(key, value, **kwargs) + + if self.redis_cache is not None and local_only == False: + result = self.redis_cache.increment_cache(key, value, **kwargs) + + return result + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") + traceback.print_exc() + raise e + def get_cache(self, key, local_only: bool = False, **kwargs): # Try to fetch from in-memory cache first try: diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index 3babe0345..6d7cc03ef 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -143,26 +143,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger): # Setup values # ------------ dt = get_utc_datetime() - current_minute = dt.strftime("%H-%M") - tpm_key = f"{model_group}:tpm:{current_minute}" - rpm_key = f"{model_group}:rpm:{current_minute}" + current_minute = 
dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + tpm_key = f"{id}:tpm:{current_minute}" # ------------ # Update usage # ------------ + # update cache ## TPM - request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} - request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens - - self.router_cache.set_cache(key=tpm_key, value=request_count_dict) - - ## RPM - request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} - request_count_dict[id] = request_count_dict.get(id, 0) + 1 - - self.router_cache.set_cache(key=rpm_key, value=request_count_dict) - + self.router_cache.increment_cache(key=tpm_key, value=total_tokens) ### TESTING ### if self.test_flag: self.logged_success += 1 diff --git a/litellm/tests/test_tpm_rpm_routing.py b/litellm/tests/test_tpm_rpm_routing copy.py similarity index 100% rename from litellm/tests/test_tpm_rpm_routing.py rename to litellm/tests/test_tpm_rpm_routing copy.py diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py new file mode 100644 index 000000000..bb6a9e45b --- /dev/null +++ b/litellm/tests/test_tpm_rpm_routing_v2.py @@ -0,0 +1,387 @@ +#### What this tests #### +# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2' + +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +from litellm import Router +import litellm +from litellm.router_strategy.lowest_tpm_rpm_v2 import ( + LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler, +) +from litellm.caching import DualCache + +### UNIT TESTS FOR TPM/RPM ROUTING ### + + +def test_tpm_rpm_updated(): + test_cache = DualCache() + model_list = [] + lowest_tpm_logger = LowestTPMLoggingHandler( + router_cache=test_cache, model_list=model_list + ) + model_group = "gpt-3.5-turbo" + deployment_id = "1234" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + end_time = time.time() + lowest_tpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + current_minute = datetime.now().strftime("%H-%M") + tpm_count_api_key = f"{model_group}:tpm:{current_minute}" + rpm_count_api_key = f"{model_group}:rpm:{current_minute}" + assert ( + response_obj["usage"]["total_tokens"] + == test_cache.get_cache(key=tpm_count_api_key)[deployment_id] + ) + assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id] + + +# test_tpm_rpm_updated() + + +def test_get_available_deployments(): + test_cache = DualCache() + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "azure/chatgpt-v-2"}, + "model_info": {"id": "1234"}, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "azure/chatgpt-v-2"}, + "model_info": {"id": "5678"}, + }, + ] + lowest_tpm_logger = LowestTPMLoggingHandler( + router_cache=test_cache, model_list=model_list + ) + model_group = "gpt-3.5-turbo" + ## DEPLOYMENT 1 ## + deployment_id = "1234" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } 
+ } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + end_time = time.time() + lowest_tpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + ## DEPLOYMENT 2 ## + deployment_id = "5678" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 20}} + end_time = time.time() + lowest_tpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + + ## CHECK WHAT'S SELECTED ## + print( + lowest_tpm_logger.get_available_deployments( + model_group=model_group, + healthy_deployments=model_list, + input=["Hello world"], + ) + ) + assert ( + lowest_tpm_logger.get_available_deployments( + model_group=model_group, + healthy_deployments=model_list, + input=["Hello world"], + )["model_info"]["id"] + == "5678" + ) + + +# test_get_available_deployments() + + +def test_router_get_available_deployments(): + """ + Test if routers 'get_available_deployments' returns the least busy deployment + """ + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "rpm": 1440, + }, + "model_info": {"id": 1}, + }, + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-35-turbo", + "api_key": "os.environ/AZURE_EUROPE_API_KEY", + "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com", + "rpm": 6, + }, + "model_info": {"id": 2}, + }, + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing", + set_verbose=False, + num_retries=3, + ) # type: ignore + + print(f"router id's: {router.get_model_ids()}") + ## DEPLOYMENT 1 ## + deployment_id = 1 + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "azure-model", + }, + "model_info": {"id": 1}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + end_time = time.time() + router.lowesttpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + ## DEPLOYMENT 2 ## + deployment_id = 2 + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "azure-model", + }, + "model_info": {"id": 2}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 20}} + end_time = time.time() + router.lowesttpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + + ## CHECK WHAT'S SELECTED ## + # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model")) + assert ( + router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2" + ) + + +# test_get_available_deployments() +# test_router_get_available_deployments() + + +def test_router_skip_rate_limited_deployments(): + """ + Test if routers 'get_available_deployments' raises No Models Available error if max tpm would be reached by message + """ + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "tpm": 1440, + }, + "model_info": {"id": 1}, + }, + ] + router = Router( + 
model_list=model_list, + routing_strategy="usage-based-routing", + set_verbose=False, + num_retries=3, + ) # type: ignore + + ## DEPLOYMENT 1 ## + deployment_id = 1 + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "azure-model", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 1439}} + end_time = time.time() + router.lowesttpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + + ## CHECK WHAT'S SELECTED ## + # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model")) + try: + router.get_available_deployment( + model="azure-model", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + pytest.fail(f"Should have raised No Models Available error") + except Exception as e: + print(f"An exception occurred! {str(e)}") + + +def test_single_deployment_tpm_zero(): + import litellm + import os + from datetime import datetime + + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + "tpm": 0, + }, + } + ] + + router = litellm.Router( + model_list=model_list, + routing_strategy="usage-based-routing", + cache_responses=True, + ) + + model = "gpt-3.5-turbo" + messages = [{"content": "Hello, how are you?", "role": "user"}] + try: + router.get_available_deployment( + model=model, + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + pytest.fail(f"Should have raised No Models Available error") + except Exception as e: + print(f"it worked - {str(e)}! \n{traceback.format_exc()}") + + +@pytest.mark.asyncio +async def test_router_completion_streaming(): + messages = [ + {"role": "user", "content": "Hello, can you generate a 500 words poem?"} + ] + model = "azure-model" + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "rpm": 1440, + }, + "model_info": {"id": 1}, + }, + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-35-turbo", + "api_key": "os.environ/AZURE_EUROPE_API_KEY", + "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com", + "rpm": 6, + }, + "model_info": {"id": 2}, + }, + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing", + set_verbose=False, + ) # type: ignore + + ### Make 3 calls, test if 3rd call goes to lowest tpm deployment + + ## CALL 1+2 + tasks = [] + response = None + final_response = None + for _ in range(2): + tasks.append(router.acompletion(model=model, messages=messages)) + response = await asyncio.gather(*tasks) + + if response is not None: + ## CALL 3 + await asyncio.sleep(1) # let the token update happen + current_minute = datetime.now().strftime("%H-%M") + picked_deployment = router.lowesttpm_logger.get_available_deployments( + model_group=model, + healthy_deployments=router.healthy_deployments, + messages=messages, + ) + final_response = await router.acompletion(model=model, messages=messages) + print(f"min deployment id: {picked_deployment}") + tpm_key = f"{model}:tpm:{current_minute}" + rpm_key = f"{model}:rpm:{current_minute}" + + tpm_dict = router.cache.get_cache(key=tpm_key) + print(f"tpm_dict: {tpm_dict}") + rpm_dict = router.cache.get_cache(key=rpm_key) + print(f"rpm_dict: {rpm_dict}") + print(f"model id: 
{final_response._hidden_params['model_id']}") + assert ( + final_response._hidden_params["model_id"] + == picked_deployment["model_info"]["id"] + ) + + +# asyncio.run(test_router_completion_streaming()) From 81573b2dd99ecd827247bf27b1883f1d2cd06790 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 18 Apr 2024 21:38:00 -0700 Subject: [PATCH 2/4] fix(test_lowest_tpm_rpm_routing_v2.py): unit testing for usage-based-routing-v2 --- litellm/_service_logger.py | 6 +- litellm/caching.py | 33 ++++--- litellm/integrations/custom_logger.py | 13 ++- litellm/router.py | 28 ++++-- litellm/router_strategy/lowest_tpm_rpm_v2.py | 94 ++++++++++++++++++-- litellm/tests/test_tpm_rpm_routing_v2.py | 50 +++++------ 6 files changed, 171 insertions(+), 53 deletions(-) diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index 0c6996b10..dc6f35642 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -21,7 +21,9 @@ class ServiceLogging(CustomLogger): if "prometheus_system" in litellm.service_callback: self.prometheusServicesLogger = PrometheusServicesLogger() - def service_success_hook(self, service: ServiceTypes, duration: float): + def service_success_hook( + self, service: ServiceTypes, duration: float, call_type: str + ): """ [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy). """ @@ -29,7 +31,7 @@ class ServiceLogging(CustomLogger): self.mock_testing_sync_success_hook += 1 def service_failure_hook( - self, service: ServiceTypes, duration: float, error: Exception + self, service: ServiceTypes, duration: float, error: Exception, call_type: str ): """ [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy). diff --git a/litellm/caching.py b/litellm/caching.py index c8ebd17df..f6d826de5 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -217,6 +217,7 @@ class RedisCache(BaseCache): self.service_logger_obj.service_success_hook( service=ServiceTypes.REDIS, duration=_duration, + call_type="increment_cache", ) ) return result @@ -226,11 +227,14 @@ class RedisCache(BaseCache): _duration = end_time - start_time asyncio.create_task( self.service_logger_obj.async_service_failure_hook( - service=ServiceTypes.REDIS, duration=_duration, error=e + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="increment_cache", ) ) verbose_logger.error( - "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s", + "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s", str(e), value, ) @@ -278,6 +282,9 @@ class RedisCache(BaseCache): async def async_set_cache(self, key, value, **kwargs): start_time = time.time() + print_verbose( + f"Set Async Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}" + ) try: _redis_client = self.init_async_client() except Exception as e: @@ -341,6 +348,10 @@ class RedisCache(BaseCache): """ _redis_client = self.init_async_client() start_time = time.time() + + print_verbose( + f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}" + ) try: async with _redis_client as redis_client: async with redis_client.pipeline(transaction=True) as pipe: @@ -1261,7 +1272,6 @@ class DualCache(BaseCache): print_verbose(f"in_memory_result: {in_memory_result}") if in_memory_result is not None: result = in_memory_result - if None in result and self.redis_cache is not None and local_only == False: """ - for the none values in 
the result @@ -1277,14 +1287,12 @@ class DualCache(BaseCache): if redis_result is not None: # Update in-memory cache with the value from Redis - for key in redis_result: - await self.in_memory_cache.async_set_cache( - key, redis_result[key], **kwargs - ) - - sublist_dict = dict(zip(sublist_keys, redis_result)) - - for key, value in sublist_dict.items(): + for key, value in redis_result.items(): + if value is not None: + await self.in_memory_cache.async_set_cache( + key, redis_result[key], **kwargs + ) + for key, value in redis_result.items(): result[sublist_keys.index(key)] = value print_verbose(f"async batch get cache: cache result: {result}") @@ -1293,6 +1301,9 @@ class DualCache(BaseCache): traceback.print_exc() async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): + print_verbose( + f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}" + ) try: if self.in_memory_cache is not None: await self.in_memory_cache.async_set_cache(key, value, **kwargs) diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 503b3ff9d..b288036ad 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -6,7 +6,7 @@ import requests from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache -from typing import Literal, Union +from typing import Literal, Union, Optional dotenv.load_dotenv() # Loading env variables using dotenv import traceback @@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): pass + #### PRE-CALL CHECKS - router/proxy only #### + """ + Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks). 
+ """ + + async def async_pre_call_check(self, deployment: dict) -> Optional[dict]: + pass + + def pre_call_check(self, deployment: dict) -> Optional[dict]: + pass + #### CALL HOOKS - proxy only #### """ Control the modify incoming / outgoung data before calling the model diff --git a/litellm/router.py b/litellm/router.py index abbd6343b..fcb5424f6 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -31,6 +31,7 @@ import copy from litellm._logging import verbose_router_logger import logging from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors +from litellm.integrations.custom_logger import CustomLogger class Router: @@ -492,18 +493,18 @@ class Router: deployment=deployment, kwargs=kwargs, client_type="rpm_client" ) - if ( - rpm_semaphore is not None - and isinstance(rpm_semaphore, asyncio.Semaphore) - and self.routing_strategy == "usage-based-routing-v2" + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore ): async with rpm_semaphore: """ - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ - await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment) + await self.routing_strategy_pre_call_checks(deployment=deployment) response = await _response else: + await self.routing_strategy_pre_call_checks(deployment=deployment) response = await _response self.success_calls[model_name] += 1 @@ -1712,6 +1713,22 @@ class Router: verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") return cooldown_models + async def routing_strategy_pre_call_checks(self, deployment: dict): + """ + For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. + + -> makes the calls concurrency-safe, when rpm limits are set for a deployment + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + response = await _callback.async_pre_call_check(deployment) + def set_client(self, model: dict): """ - Initializes Azure/OpenAI clients. 
Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 @@ -2700,6 +2717,7 @@ class Router: verbose_router_logger.info( f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" ) + return deployment def get_available_deployment( diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index 6d7cc03ef..022bf5ffe 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -39,7 +39,81 @@ class LowestTPMLoggingHandler_v2(CustomLogger): self.router_cache = router_cache self.model_list = model_list - async def pre_call_rpm_check(self, deployment: dict) -> dict: + def pre_call_check(self, deployment: Dict) -> Dict | None: + """ + Pre-call check + update model rpm + + Returns - deployment + + Raises - RateLimitError if deployment over defined RPM limit + """ + try: + + # ------------ + # Setup values + # ------------ + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + model_id = deployment.get("model_info", {}).get("id") + rpm_key = f"{model_id}:rpm:{current_minute}" + local_result = self.router_cache.get_cache( + key=rpm_key, local_only=True + ) # check local result first + + deployment_rpm = None + if deployment_rpm is None: + deployment_rpm = deployment.get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("litellm_params", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("model_info", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = float("inf") + + if local_result is not None and local_result >= deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, local_result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + local_result, + ), + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + else: + # if local result below limit, check redis ## prevent unnecessary redis checks + result = self.router_cache.increment_cache(key=rpm_key, value=1) + if result is not None and result > deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + result, + ), + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return deployment + except Exception as e: + if isinstance(e, litellm.RateLimitError): + raise e + return deployment # don't fail calls if eg. 
redis fails to connect + + async def async_pre_call_check(self, deployment: Dict) -> Dict | None: """ Pre-call check + update model rpm - Used inside semaphore @@ -58,8 +132,8 @@ class LowestTPMLoggingHandler_v2(CustomLogger): # ------------ dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") - model_group = deployment.get("model_name", "") - rpm_key = f"{model_group}:rpm:{current_minute}" + model_id = deployment.get("model_info", {}).get("id") + rpm_key = f"{model_id}:rpm:{current_minute}" local_result = await self.router_cache.async_get_cache( key=rpm_key, local_only=True ) # check local result first @@ -246,21 +320,26 @@ class LowestTPMLoggingHandler_v2(CustomLogger): for deployment in healthy_deployments: tpm_dict[deployment["model_info"]["id"]] = 0 else: + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + for d in healthy_deployments: ## if healthy deployment not yet used - if d["model_info"]["id"] not in tpm_dict: - tpm_dict[d["model_info"]["id"]] = 0 + tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}" + if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None: + tpm_dict[tpm_key] = 0 all_deployments = tpm_dict - deployment = None for item, item_tpm in all_deployments.items(): ## get the item from model list _deployment = None + item = item.split(":")[0] for m in healthy_deployments: if item == m["model_info"]["id"]: _deployment = m - if _deployment is None: continue # skip to next one @@ -283,7 +362,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger): _deployment_rpm = _deployment.get("model_info", {}).get("rpm") if _deployment_rpm is None: _deployment_rpm = float("inf") - if item_tpm + input_tokens > _deployment_tpm: continue elif (rpm_dict is not None and item in rpm_dict) and ( diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py index bb6a9e45b..84728a78c 100644 --- a/litellm/tests/test_tpm_rpm_routing_v2.py +++ b/litellm/tests/test_tpm_rpm_routing_v2.py @@ -1,5 +1,5 @@ #### What this tests #### -# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2' +# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2-v2' import sys, os, asyncio, time, random from datetime import datetime @@ -18,6 +18,7 @@ import litellm from litellm.router_strategy.lowest_tpm_rpm_v2 import ( LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler, ) +from litellm.utils import get_utc_datetime from litellm.caching import DualCache ### UNIT TESTS FOR TPM/RPM ROUTING ### @@ -43,20 +44,23 @@ def test_tpm_rpm_updated(): start_time = time.time() response_obj = {"usage": {"total_tokens": 50}} end_time = time.time() + lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"]) lowest_tpm_logger.log_success_event( response_obj=response_obj, kwargs=kwargs, start_time=start_time, end_time=end_time, ) - current_minute = datetime.now().strftime("%H-%M") - tpm_count_api_key = f"{model_group}:tpm:{current_minute}" - rpm_count_api_key = f"{model_group}:rpm:{current_minute}" - assert ( - response_obj["usage"]["total_tokens"] - == test_cache.get_cache(key=tpm_count_api_key)[deployment_id] + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}" + rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}" + + print(f"tpm_count_api_key={tpm_count_api_key}") + assert response_obj["usage"]["total_tokens"] == test_cache.get_cache( + 
key=tpm_count_api_key ) - assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id] + assert 1 == test_cache.get_cache(key=rpm_count_api_key) # test_tpm_rpm_updated() @@ -122,13 +126,6 @@ def test_get_available_deployments(): ) ## CHECK WHAT'S SELECTED ## - print( - lowest_tpm_logger.get_available_deployments( - model_group=model_group, - healthy_deployments=model_list, - input=["Hello world"], - ) - ) assert ( lowest_tpm_logger.get_available_deployments( model_group=model_group, @@ -170,7 +167,7 @@ def test_router_get_available_deployments(): ] router = Router( model_list=model_list, - routing_strategy="usage-based-routing", + routing_strategy="usage-based-routing-v2", set_verbose=False, num_retries=3, ) # type: ignore @@ -189,7 +186,7 @@ def test_router_get_available_deployments(): start_time = time.time() response_obj = {"usage": {"total_tokens": 50}} end_time = time.time() - router.lowesttpm_logger.log_success_event( + router.lowesttpm_logger_v2.log_success_event( response_obj=response_obj, kwargs=kwargs, start_time=start_time, @@ -208,7 +205,7 @@ def test_router_get_available_deployments(): start_time = time.time() response_obj = {"usage": {"total_tokens": 20}} end_time = time.time() - router.lowesttpm_logger.log_success_event( + router.lowesttpm_logger_v2.log_success_event( response_obj=response_obj, kwargs=kwargs, start_time=start_time, @@ -216,7 +213,7 @@ def test_router_get_available_deployments(): ) ## CHECK WHAT'S SELECTED ## - # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model")) + # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model")) assert ( router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2" ) @@ -244,7 +241,7 @@ def test_router_skip_rate_limited_deployments(): ] router = Router( model_list=model_list, - routing_strategy="usage-based-routing", + routing_strategy="usage-based-routing-v2", set_verbose=False, num_retries=3, ) # type: ignore @@ -262,7 +259,7 @@ def test_router_skip_rate_limited_deployments(): start_time = time.time() response_obj = {"usage": {"total_tokens": 1439}} end_time = time.time() - router.lowesttpm_logger.log_success_event( + router.lowesttpm_logger_v2.log_success_event( response_obj=response_obj, kwargs=kwargs, start_time=start_time, @@ -270,7 +267,7 @@ def test_router_skip_rate_limited_deployments(): ) ## CHECK WHAT'S SELECTED ## - # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model")) + # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model")) try: router.get_available_deployment( model="azure-model", @@ -299,7 +296,7 @@ def test_single_deployment_tpm_zero(): router = litellm.Router( model_list=model_list, - routing_strategy="usage-based-routing", + routing_strategy="usage-based-routing-v2", cache_responses=True, ) @@ -345,7 +342,7 @@ async def test_router_completion_streaming(): ] router = Router( model_list=model_list, - routing_strategy="usage-based-routing", + routing_strategy="usage-based-routing-v2", set_verbose=False, ) # type: ignore @@ -362,8 +359,9 @@ async def test_router_completion_streaming(): if response is not None: ## CALL 3 await asyncio.sleep(1) # let the token update happen - current_minute = datetime.now().strftime("%H-%M") - picked_deployment = router.lowesttpm_logger.get_available_deployments( + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + picked_deployment = router.lowesttpm_logger_v2.get_available_deployments( model_group=model, 
healthy_deployments=router.healthy_deployments, messages=messages, From 3b9e2a58e2d9a01bd100baac90787c41d4246646 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 18 Apr 2024 21:42:35 -0700 Subject: [PATCH 3/4] fix(lowest_tpm_rpm_v2.py): ensure backwards compatibility for python 3.8 --- litellm/router_strategy/lowest_tpm_rpm_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index 022bf5ffe..4f6364c2b 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -39,7 +39,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): self.router_cache = router_cache self.model_list = model_list - def pre_call_check(self, deployment: Dict) -> Dict | None: + def pre_call_check(self, deployment: Dict) -> Optional[Dict]: """ Pre-call check + update model rpm @@ -113,7 +113,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): raise e return deployment # don't fail calls if eg. redis fails to connect - async def async_pre_call_check(self, deployment: Dict) -> Dict | None: + async def async_pre_call_check(self, deployment: Dict) -> Optional[Dict]: """ Pre-call check + update model rpm - Used inside semaphore From 9c42c847a529696a76adbaa75f063b78353f9b96 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 18 Apr 2024 21:54:25 -0700 Subject: [PATCH 4/4] fix(router.py): instrument pre-call-checks for all openai endpoints --- litellm/caching.py | 3 - litellm/router.py | 132 +++++++++++++++++++++-- litellm/tests/test_tpm_rpm_routing_v2.py | 5 + 3 files changed, 130 insertions(+), 10 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index f6d826de5..d73112d21 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -282,9 +282,6 @@ class RedisCache(BaseCache): async def async_set_cache(self, key, value, **kwargs): start_time = time.time() - print_verbose( - f"Set Async Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}" - ) try: _redis_client = self.init_async_client() except Exception as e: diff --git a/litellm/router.py b/litellm/router.py index fcb5424f6..e393345da 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -379,6 +379,9 @@ class Router: else: model_client = potential_model_client + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + response = litellm.completion( **{ **data, @@ -391,6 +394,7 @@ class Router: verbose_router_logger.info( f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m" ) + return response except Exception as e: verbose_router_logger.info( @@ -501,10 +505,12 @@ class Router: - Check rpm limits before making the call - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ - await self.routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) response = await _response else: - await self.routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks(deployment=deployment) response = await _response self.success_calls[model_name] += 1 @@ -579,6 +585,10 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 + + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. 
Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + response = litellm.image_generation( **{ **data, @@ -657,7 +667,7 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.aimage_generation( + response = litellm.aimage_generation( **{ **data, "prompt": prompt, @@ -666,6 +676,28 @@ class Router: **kwargs, } ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m" @@ -757,7 +789,7 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.atranscription( + response = litellm.atranscription( **{ **data, "file": file, @@ -766,6 +798,28 @@ class Router: **kwargs, } ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m" @@ -979,7 +1033,8 @@ class Router: else: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.atext_completion( + + response = litellm.atext_completion( **{ **data, "prompt": prompt, @@ -989,6 +1044,27 @@ class Router: **kwargs, } ) + + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m" @@ -1063,6 +1139,10 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 + + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. 
Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + response = litellm.embedding( **{ **data, @@ -1147,7 +1227,7 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.aembedding( + response = litellm.aembedding( **{ **data, "input": input, @@ -1156,6 +1236,28 @@ class Router: **kwargs, } ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m" @@ -1713,7 +1815,23 @@ class Router: verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") return cooldown_models - async def routing_strategy_pre_call_checks(self, deployment: dict): + def routing_strategy_pre_call_checks(self, deployment: dict): + """ + Mimics 'async_routing_strategy_pre_call_checks' + + Ensures consistent update rpm implementation for 'usage-based-routing-v2' + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + response = _callback.pre_call_check(deployment) + + async def async_routing_strategy_pre_call_checks(self, deployment: dict): """ For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py index 84728a78c..4a0256f6a 100644 --- a/litellm/tests/test_tpm_rpm_routing_v2.py +++ b/litellm/tests/test_tpm_rpm_routing_v2.py @@ -383,3 +383,8 @@ async def test_router_completion_streaming(): # asyncio.run(test_router_completion_streaming()) + +""" +- Unit test for sync 'pre_call_checks' +- Unit test for async 'async_pre_call_checks' +"""
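
The closing note in the last patch lists two follow-up unit tests that are not yet written. Below is a minimal sketch of what they could look like, assuming the handler surface added in this patch series (LowestTPMLoggingHandler_v2.pre_call_check / async_pre_call_check taking a deployment dict and raising litellm.RateLimitError once the per-minute rpm counter reaches the deployment's limit) and assuming the async path mirrors the sync one, with DualCache falling back to its in-memory store when no Redis is configured. The deployment name, id, and rpm value are illustrative only, not taken from the patches.

import pytest

import litellm
from litellm.caching import DualCache
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2

# hypothetical deployment; rpm=1 so the second call in the same minute trips the limit
ILLUSTRATIVE_DEPLOYMENT = {
    "model_name": "azure-model",
    "litellm_params": {"model": "azure/gpt-turbo", "rpm": 1},
    "model_info": {"id": "pre-call-check-test"},
}


def test_pre_call_check_rpm_limit():
    # sync path: first call is under the rpm limit, second call in the same minute should raise
    handler = LowestTPMLoggingHandler_v2(router_cache=DualCache(), model_list=[])
    assert handler.pre_call_check(deployment=ILLUSTRATIVE_DEPLOYMENT) == ILLUSTRATIVE_DEPLOYMENT
    with pytest.raises(litellm.RateLimitError):
        handler.pre_call_check(deployment=ILLUSTRATIVE_DEPLOYMENT)


@pytest.mark.asyncio
async def test_async_pre_call_check_rpm_limit():
    # async path: assumes it mirrors the sync behaviour and uses the in-memory
    # cache when no Redis client is configured on the DualCache
    handler = LowestTPMLoggingHandler_v2(router_cache=DualCache(), model_list=[])
    result = await handler.async_pre_call_check(deployment=ILLUSTRATIVE_DEPLOYMENT)
    assert result == ILLUSTRATIVE_DEPLOYMENT
    with pytest.raises(litellm.RateLimitError):
        await handler.async_pre_call_check(deployment=ILLUSTRATIVE_DEPLOYMENT)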