diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index 0c6996b10..dc6f35642 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -21,7 +21,9 @@ class ServiceLogging(CustomLogger): if "prometheus_system" in litellm.service_callback: self.prometheusServicesLogger = PrometheusServicesLogger() - def service_success_hook(self, service: ServiceTypes, duration: float): + def service_success_hook( + self, service: ServiceTypes, duration: float, call_type: str + ): """ [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy). """ @@ -29,7 +31,7 @@ class ServiceLogging(CustomLogger): self.mock_testing_sync_success_hook += 1 def service_failure_hook( - self, service: ServiceTypes, duration: float, error: Exception + self, service: ServiceTypes, duration: float, error: Exception, call_type: str ): """ [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy). diff --git a/litellm/caching.py b/litellm/caching.py index bf1d61eec..d73112d21 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -89,6 +89,13 @@ class InMemoryCache(BaseCache): return_val.append(val) return return_val + def increment_cache(self, key, value: int, **kwargs) -> int: + # get the value + init_value = self.get_cache(key=key) or 0 + value = init_value + value + self.set_cache(key, value, **kwargs) + return value + async def async_get_cache(self, key, **kwargs): return self.get_cache(key=key, **kwargs) @@ -198,6 +205,42 @@ class RedisCache(BaseCache): f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}" ) + def increment_cache(self, key, value: int, **kwargs) -> int: + _redis_client = self.redis_client + start_time = time.time() + try: + result = _redis_client.incr(name=key, amount=value) + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="increment_cache", + ) + ) + return result + except Exception as e: + ## LOGGING ## + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="increment_cache", + ) + ) + verbose_logger.error( + "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + traceback.print_exc() + raise e + async def async_scan_iter(self, pattern: str, count: int = 100) -> list: start_time = time.time() try: @@ -302,6 +345,10 @@ class RedisCache(BaseCache): """ _redis_client = self.init_async_client() start_time = time.time() + + print_verbose( + f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}" + ) try: async with _redis_client as redis_client: async with redis_client.pipeline(transaction=True) as pipe: @@ -1093,6 +1140,30 @@ class DualCache(BaseCache): except Exception as e: print_verbose(e) + def increment_cache( + self, key, value: int, local_only: bool = False, **kwargs + ) -> int: + """ + Key - the key in cache + + Value - int - the value you want to increment by + + Returns - int - the incremented value + """ + try: + result: int = value + if self.in_memory_cache is not None: + result = self.in_memory_cache.increment_cache(key, value, **kwargs) + + if self.redis_cache is not None and local_only == False: + result = self.redis_cache.increment_cache(key, value, **kwargs) + 
+ return result + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") + traceback.print_exc() + raise e + def get_cache(self, key, local_only: bool = False, **kwargs): # Try to fetch from in-memory cache first try: @@ -1198,7 +1269,6 @@ class DualCache(BaseCache): print_verbose(f"in_memory_result: {in_memory_result}") if in_memory_result is not None: result = in_memory_result - if None in result and self.redis_cache is not None and local_only == False: """ - for the none values in the result @@ -1214,14 +1284,12 @@ class DualCache(BaseCache): if redis_result is not None: # Update in-memory cache with the value from Redis - for key in redis_result: - await self.in_memory_cache.async_set_cache( - key, redis_result[key], **kwargs - ) - - sublist_dict = dict(zip(sublist_keys, redis_result)) - - for key, value in sublist_dict.items(): + for key, value in redis_result.items(): + if value is not None: + await self.in_memory_cache.async_set_cache( + key, redis_result[key], **kwargs + ) + for key, value in redis_result.items(): result[sublist_keys.index(key)] = value print_verbose(f"async batch get cache: cache result: {result}") @@ -1230,6 +1298,9 @@ class DualCache(BaseCache): traceback.print_exc() async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): + print_verbose( + f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}" + ) try: if self.in_memory_cache is not None: await self.in_memory_cache.async_set_cache(key, value, **kwargs) diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 503b3ff9d..b288036ad 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -6,7 +6,7 @@ import requests from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache -from typing import Literal, Union +from typing import Literal, Union, Optional dotenv.load_dotenv() # Loading env variables using dotenv import traceback @@ -46,6 +46,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): pass + #### PRE-CALL CHECKS - router/proxy only #### + """ + Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks). + """ + + async def async_pre_call_check(self, deployment: dict) -> Optional[dict]: + pass + + def pre_call_check(self, deployment: dict) -> Optional[dict]: + pass + #### CALL HOOKS - proxy only #### """ Control the modify incoming / outgoung data before calling the model diff --git a/litellm/router.py b/litellm/router.py index e67507318..8145ef619 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -31,6 +31,7 @@ import copy from litellm._logging import verbose_router_logger import logging from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors +from litellm.integrations.custom_logger import CustomLogger class Router: @@ -379,6 +380,9 @@ class Router: else: model_client = potential_model_client + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. 
Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + response = litellm.completion( **{ **data, @@ -391,6 +395,7 @@ class Router: verbose_router_logger.info( f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m" ) + return response except Exception as e: verbose_router_logger.info( @@ -494,18 +499,20 @@ class Router: deployment=deployment, kwargs=kwargs, client_type="rpm_client" ) - if ( - rpm_semaphore is not None - and isinstance(rpm_semaphore, asyncio.Semaphore) - and self.routing_strategy == "usage-based-routing-v2" + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore ): async with rpm_semaphore: """ - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ - await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) response = await _response else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) response = await _response self.success_calls[model_name] += 1 @@ -580,6 +587,10 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 + + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + response = litellm.image_generation( **{ **data, @@ -658,7 +669,7 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.aimage_generation( + response = litellm.aimage_generation( **{ **data, "prompt": prompt, @@ -667,6 +678,28 @@ class Router: **kwargs, } ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m" @@ -758,7 +791,7 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.atranscription( + response = litellm.atranscription( **{ **data, "file": file, @@ -767,6 +800,28 @@ class Router: **kwargs, } ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.atranscription(model={model_name})\033[32m 200 
OK\033[0m" @@ -981,7 +1036,8 @@ class Router: else: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.atext_completion( + + response = litellm.atext_completion( **{ **data, "prompt": prompt, @@ -991,6 +1047,27 @@ class Router: **kwargs, } ) + + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m" @@ -1065,6 +1142,10 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 + + ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit) + self.routing_strategy_pre_call_checks(deployment=deployment) + response = litellm.embedding( **{ **data, @@ -1150,7 +1231,7 @@ class Router: model_client = potential_model_client self.total_calls[model_name] += 1 - response = await litellm.aembedding( + response = litellm.aembedding( **{ **data, "input": input, @@ -1159,6 +1240,28 @@ class Router: **kwargs, } ) + + ### CONCURRENCY-SAFE RPM CHECKS ### + rpm_semaphore = self._get_client( + deployment=deployment, kwargs=kwargs, client_type="rpm_client" + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment + ) + response = await response + else: + await self.async_routing_strategy_pre_call_checks(deployment=deployment) + response = await response + self.success_calls[model_name] += 1 verbose_router_logger.info( f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m" @@ -1716,6 +1819,38 @@ class Router: verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") return cooldown_models + def routing_strategy_pre_call_checks(self, deployment: dict): + """ + Mimics 'async_routing_strategy_pre_call_checks' + + Ensures consistent update rpm implementation for 'usage-based-routing-v2' + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + response = _callback.pre_call_check(deployment) + + async def async_routing_strategy_pre_call_checks(self, deployment: dict): + """ + For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. + + -> makes the calls concurrency-safe, when rpm limits are set for a deployment + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + response = await _callback.async_pre_call_check(deployment) + def set_client(self, model: dict): """ - Initializes Azure/OpenAI clients. 
Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 @@ -2704,6 +2839,7 @@ class Router: verbose_router_logger.info( f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" ) + return deployment def get_available_deployment( diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index 3babe0345..4f6364c2b 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -39,7 +39,81 @@ class LowestTPMLoggingHandler_v2(CustomLogger): self.router_cache = router_cache self.model_list = model_list - async def pre_call_rpm_check(self, deployment: dict) -> dict: + def pre_call_check(self, deployment: Dict) -> Optional[Dict]: + """ + Pre-call check + update model rpm + + Returns - deployment + + Raises - RateLimitError if deployment over defined RPM limit + """ + try: + + # ------------ + # Setup values + # ------------ + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + model_id = deployment.get("model_info", {}).get("id") + rpm_key = f"{model_id}:rpm:{current_minute}" + local_result = self.router_cache.get_cache( + key=rpm_key, local_only=True + ) # check local result first + + deployment_rpm = None + if deployment_rpm is None: + deployment_rpm = deployment.get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("litellm_params", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = deployment.get("model_info", {}).get("rpm") + if deployment_rpm is None: + deployment_rpm = float("inf") + + if local_result is not None and local_result >= deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, local_result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + local_result, + ), + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + else: + # if local result below limit, check redis ## prevent unnecessary redis checks + result = self.router_cache.increment_cache(key=rpm_key, value=1) + if result is not None and result > deployment_rpm: + raise litellm.RateLimitError( + message="Deployment over defined rpm limit={}. current usage={}".format( + deployment_rpm, result + ), + llm_provider="", + model=deployment.get("litellm_params", {}).get("model"), + response=httpx.Response( + status_code=429, + content="{} rpm limit={}. current usage={}".format( + RouterErrors.user_defined_ratelimit_error.value, + deployment_rpm, + result, + ), + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return deployment + except Exception as e: + if isinstance(e, litellm.RateLimitError): + raise e + return deployment # don't fail calls if eg. 
redis fails to connect + + async def async_pre_call_check(self, deployment: Dict) -> Optional[Dict]: """ Pre-call check + update model rpm - Used inside semaphore @@ -58,8 +132,8 @@ class LowestTPMLoggingHandler_v2(CustomLogger): # ------------ dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") - model_group = deployment.get("model_name", "") - rpm_key = f"{model_group}:rpm:{current_minute}" + model_id = deployment.get("model_info", {}).get("id") + rpm_key = f"{model_id}:rpm:{current_minute}" local_result = await self.router_cache.async_get_cache( key=rpm_key, local_only=True ) # check local result first @@ -143,26 +217,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger): # Setup values # ------------ dt = get_utc_datetime() - current_minute = dt.strftime("%H-%M") - tpm_key = f"{model_group}:tpm:{current_minute}" - rpm_key = f"{model_group}:rpm:{current_minute}" + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + tpm_key = f"{id}:tpm:{current_minute}" # ------------ # Update usage # ------------ + # update cache ## TPM - request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} - request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens - - self.router_cache.set_cache(key=tpm_key, value=request_count_dict) - - ## RPM - request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} - request_count_dict[id] = request_count_dict.get(id, 0) + 1 - - self.router_cache.set_cache(key=rpm_key, value=request_count_dict) - + self.router_cache.increment_cache(key=tpm_key, value=total_tokens) ### TESTING ### if self.test_flag: self.logged_success += 1 @@ -254,21 +320,26 @@ class LowestTPMLoggingHandler_v2(CustomLogger): for deployment in healthy_deployments: tpm_dict[deployment["model_info"]["id"]] = 0 else: + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + for d in healthy_deployments: ## if healthy deployment not yet used - if d["model_info"]["id"] not in tpm_dict: - tpm_dict[d["model_info"]["id"]] = 0 + tpm_key = f"{d['model_info']['id']}:tpm:{current_minute}" + if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None: + tpm_dict[tpm_key] = 0 all_deployments = tpm_dict - deployment = None for item, item_tpm in all_deployments.items(): ## get the item from model list _deployment = None + item = item.split(":")[0] for m in healthy_deployments: if item == m["model_info"]["id"]: _deployment = m - if _deployment is None: continue # skip to next one @@ -291,7 +362,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger): _deployment_rpm = _deployment.get("model_info", {}).get("rpm") if _deployment_rpm is None: _deployment_rpm = float("inf") - if item_tpm + input_tokens > _deployment_tpm: continue elif (rpm_dict is not None and item in rpm_dict) and ( diff --git a/litellm/tests/test_tpm_rpm_routing.py b/litellm/tests/test_tpm_rpm_routing copy.py similarity index 100% rename from litellm/tests/test_tpm_rpm_routing.py rename to litellm/tests/test_tpm_rpm_routing copy.py diff --git a/litellm/tests/test_tpm_rpm_routing_v2.py b/litellm/tests/test_tpm_rpm_routing_v2.py new file mode 100644 index 000000000..4a0256f6a --- /dev/null +++ b/litellm/tests/test_tpm_rpm_routing_v2.py @@ -0,0 +1,390 @@ +#### What this tests #### +# This tests the router's ability to pick deployment with lowest tpm using 'usage-based-routing-v2-v2' + +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv + +load_dotenv() 
+import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +from litellm import Router +import litellm +from litellm.router_strategy.lowest_tpm_rpm_v2 import ( + LowestTPMLoggingHandler_v2 as LowestTPMLoggingHandler, +) +from litellm.utils import get_utc_datetime +from litellm.caching import DualCache + +### UNIT TESTS FOR TPM/RPM ROUTING ### + + +def test_tpm_rpm_updated(): + test_cache = DualCache() + model_list = [] + lowest_tpm_logger = LowestTPMLoggingHandler( + router_cache=test_cache, model_list=model_list + ) + model_group = "gpt-3.5-turbo" + deployment_id = "1234" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + end_time = time.time() + lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"]) + lowest_tpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}" + rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}" + + print(f"tpm_count_api_key={tpm_count_api_key}") + assert response_obj["usage"]["total_tokens"] == test_cache.get_cache( + key=tpm_count_api_key + ) + assert 1 == test_cache.get_cache(key=rpm_count_api_key) + + +# test_tpm_rpm_updated() + + +def test_get_available_deployments(): + test_cache = DualCache() + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "azure/chatgpt-v-2"}, + "model_info": {"id": "1234"}, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "azure/chatgpt-v-2"}, + "model_info": {"id": "5678"}, + }, + ] + lowest_tpm_logger = LowestTPMLoggingHandler( + router_cache=test_cache, model_list=model_list + ) + model_group = "gpt-3.5-turbo" + ## DEPLOYMENT 1 ## + deployment_id = "1234" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + end_time = time.time() + lowest_tpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + ## DEPLOYMENT 2 ## + deployment_id = "5678" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 20}} + end_time = time.time() + lowest_tpm_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + + ## CHECK WHAT'S SELECTED ## + assert ( + lowest_tpm_logger.get_available_deployments( + model_group=model_group, + healthy_deployments=model_list, + input=["Hello world"], + )["model_info"]["id"] + == "5678" + ) + + +# test_get_available_deployments() + + +def test_router_get_available_deployments(): + """ + Test if routers 'get_available_deployments' returns the least busy deployment + """ + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": 
"https://openai-france-1234.openai.azure.com", + "rpm": 1440, + }, + "model_info": {"id": 1}, + }, + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-35-turbo", + "api_key": "os.environ/AZURE_EUROPE_API_KEY", + "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com", + "rpm": 6, + }, + "model_info": {"id": 2}, + }, + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + set_verbose=False, + num_retries=3, + ) # type: ignore + + print(f"router id's: {router.get_model_ids()}") + ## DEPLOYMENT 1 ## + deployment_id = 1 + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "azure-model", + }, + "model_info": {"id": 1}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + end_time = time.time() + router.lowesttpm_logger_v2.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + ## DEPLOYMENT 2 ## + deployment_id = 2 + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "azure-model", + }, + "model_info": {"id": 2}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 20}} + end_time = time.time() + router.lowesttpm_logger_v2.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + + ## CHECK WHAT'S SELECTED ## + # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model")) + assert ( + router.get_available_deployment(model="azure-model")["model_info"]["id"] == "2" + ) + + +# test_get_available_deployments() +# test_router_get_available_deployments() + + +def test_router_skip_rate_limited_deployments(): + """ + Test if routers 'get_available_deployments' raises No Models Available error if max tpm would be reached by message + """ + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "tpm": 1440, + }, + "model_info": {"id": 1}, + }, + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + set_verbose=False, + num_retries=3, + ) # type: ignore + + ## DEPLOYMENT 1 ## + deployment_id = 1 + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "azure-model", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 1439}} + end_time = time.time() + router.lowesttpm_logger_v2.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + + ## CHECK WHAT'S SELECTED ## + # print(router.lowesttpm_logger_v2.get_available_deployments(model_group="azure-model")) + try: + router.get_available_deployment( + model="azure-model", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + pytest.fail(f"Should have raised No Models Available error") + except Exception as e: + print(f"An exception occurred! 
{str(e)}") + + +def test_single_deployment_tpm_zero(): + import litellm + import os + from datetime import datetime + + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + "tpm": 0, + }, + } + ] + + router = litellm.Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + cache_responses=True, + ) + + model = "gpt-3.5-turbo" + messages = [{"content": "Hello, how are you?", "role": "user"}] + try: + router.get_available_deployment( + model=model, + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + pytest.fail(f"Should have raised No Models Available error") + except Exception as e: + print(f"it worked - {str(e)}! \n{traceback.format_exc()}") + + +@pytest.mark.asyncio +async def test_router_completion_streaming(): + messages = [ + {"role": "user", "content": "Hello, can you generate a 500 words poem?"} + ] + model = "azure-model" + model_list = [ + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-turbo", + "api_key": "os.environ/AZURE_FRANCE_API_KEY", + "api_base": "https://openai-france-1234.openai.azure.com", + "rpm": 1440, + }, + "model_info": {"id": 1}, + }, + { + "model_name": "azure-model", + "litellm_params": { + "model": "azure/gpt-35-turbo", + "api_key": "os.environ/AZURE_EUROPE_API_KEY", + "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com", + "rpm": 6, + }, + "model_info": {"id": 2}, + }, + ] + router = Router( + model_list=model_list, + routing_strategy="usage-based-routing-v2", + set_verbose=False, + ) # type: ignore + + ### Make 3 calls, test if 3rd call goes to lowest tpm deployment + + ## CALL 1+2 + tasks = [] + response = None + final_response = None + for _ in range(2): + tasks.append(router.acompletion(model=model, messages=messages)) + response = await asyncio.gather(*tasks) + + if response is not None: + ## CALL 3 + await asyncio.sleep(1) # let the token update happen + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + picked_deployment = router.lowesttpm_logger_v2.get_available_deployments( + model_group=model, + healthy_deployments=router.healthy_deployments, + messages=messages, + ) + final_response = await router.acompletion(model=model, messages=messages) + print(f"min deployment id: {picked_deployment}") + tpm_key = f"{model}:tpm:{current_minute}" + rpm_key = f"{model}:rpm:{current_minute}" + + tpm_dict = router.cache.get_cache(key=tpm_key) + print(f"tpm_dict: {tpm_dict}") + rpm_dict = router.cache.get_cache(key=rpm_key) + print(f"rpm_dict: {rpm_dict}") + print(f"model id: {final_response._hidden_params['model_id']}") + assert ( + final_response._hidden_params["model_id"] + == picked_deployment["model_info"]["id"] + ) + + +# asyncio.run(test_router_completion_streaming()) + +""" +- Unit test for sync 'pre_call_checks' +- Unit test for async 'async_pre_call_checks' +"""