diff --git a/litellm/router.py b/litellm/router.py
index d195c67b1..512a47a34 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -18,6 +18,7 @@ import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.llms.custom_httpx.azure_dall_e_2 import (
     CustomHTTPTransport,
     AsyncCustomHTTPTransport,
 )
@@ -67,6 +68,7 @@ class Router:
     num_retries: int = 0
     tenacity = None
     leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
+    lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None

     def __init__(
         self,
@@ -196,12 +198,14 @@ class Router:
                 litellm.input_callback = [self.leastbusy_logger]  # type: ignore
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.leastbusy_logger)  # type: ignore
-        ## USAGE TRACKING ##
-        if isinstance(litellm.success_callback, list):
-            litellm.success_callback.append(self.deployment_callback)
-        else:
-            litellm.success_callback = [self.deployment_callback]
+        elif routing_strategy == "usage-based-routing":
+            self.lowesttpm_logger = LowestTPMLoggingHandler(
+                router_cache=self.cache, model_list=self.model_list
+            )
+            if isinstance(litellm.callbacks, list):
+                litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
+        ## COOLDOWNS ##
         if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
         else:
@@ -1012,40 +1016,6 @@ class Router:

     ### HELPER FUNCTIONS

-    def deployment_callback(
-        self,
-        kwargs,  # kwargs to completion
-        completion_response,  # response from completion
-        start_time,
-        end_time,  # start/end time
-    ):
-        """
-        Function LiteLLM submits a callback to after a successful
-        completion. Purpose of this is to update TPM/RPM usage per model
-        """
-        deployment_id = (
-            kwargs.get("litellm_params", {}).get("model_info", {}).get("id", None)
-        )
-        model_name = kwargs.get("model", None)  # i.e. gpt35turbo
-        custom_llm_provider = kwargs.get("litellm_params", {}).get(
-            "custom_llm_provider", None
-        )  # i.e. azure
-        if custom_llm_provider:
-            model_name = f"{custom_llm_provider}/{model_name}"
-        if kwargs["stream"] is True:
-            if kwargs.get("complete_streaming_response"):
-                total_tokens = kwargs.get("complete_streaming_response")["usage"][
-                    "total_tokens"
-                ]
-                self._set_deployment_usage(deployment_id, total_tokens)
-        else:
-            total_tokens = completion_response["usage"]["total_tokens"]
-            self._set_deployment_usage(deployment_id, total_tokens)
-
-        self.deployment_latency_map[model_name] = (
-            end_time - start_time
-        ).total_seconds()
-
     def deployment_callback_on_failure(
         self,
         kwargs,  # kwargs to completion
@@ -1180,109 +1150,6 @@ class Router:
         self.print_verbose(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models

-    def get_usage_based_available_deployment(
-        self,
-        model: str,
-        messages: Optional[List[Dict[str, str]]] = None,
-        input: Optional[Union[str, List]] = None,
-    ):
-        """
-        Returns a deployment with the lowest TPM/RPM usage.
- """ - # get list of potential deployments - potential_deployments = [] - for item in self.model_list: - if item["model_name"] == model: - potential_deployments.append(item) - - # get current call usage - token_count = 0 - if messages is not None: - token_count = litellm.token_counter(model=model, messages=messages) - elif input is not None: - if isinstance(input, List): - input_text = "".join(text for text in input) - else: - input_text = input - token_count = litellm.token_counter(model=model, text=input_text) - - # ----------------------- - # Find lowest used model - # ---------------------- - lowest_tpm = float("inf") - deployment = None - - # load model context map - models_context_map = litellm.model_cost - - # return deployment with lowest tpm usage - for item in potential_deployments: - model_id = item["model_info"].get("id") - item_tpm, item_rpm = self._get_deployment_usage(deployment_name=model_id) - - if item_tpm == 0: - return item - elif ( - "tpm" in item - and item_tpm + token_count > item["tpm"] - or "rpm" in item - and item_rpm + 1 >= item["rpm"] - ): # if user passed in tpm / rpm in the model_list - continue - elif item_tpm < lowest_tpm: - lowest_tpm = item_tpm - deployment = item - - # if none, raise exception - if deployment is None: - raise ValueError("No models available.") - - # return model - return deployment - - def _get_deployment_usage(self, deployment_name: str): - # ------------ - # Setup values - # ------------ - current_minute = datetime.now().strftime("%H-%M") - tpm_key = f"{deployment_name}:tpm:{current_minute}" - rpm_key = f"{deployment_name}:rpm:{current_minute}" - - # ------------ - # Return usage - # ------------ - tpm = self.cache.get_cache(key=tpm_key) or 0 - rpm = self.cache.get_cache(key=rpm_key) or 0 - - return int(tpm), int(rpm) - - def increment(self, key: str, increment_value: int): - # get value - cached_value = self.cache.get_cache(key=key) - # update value - try: - cached_value = cached_value + increment_value - except: - cached_value = increment_value - # save updated value - self.cache.set_cache( - value=cached_value, key=key, ttl=self.default_cache_time_seconds - ) - - def _set_deployment_usage(self, model_name: str, total_tokens: int): - # ------------ - # Setup values - # ------------ - current_minute = datetime.now().strftime("%H-%M") - tpm_key = f"{model_name}:tpm:{current_minute}" - rpm_key = f"{model_name}:rpm:{current_minute}" - - # ------------ - # Update usage - # ------------ - self.increment(tpm_key, total_tokens) - self.increment(rpm_key, 1) - def _start_health_check_thread(self): """ Starts a separate thread to perform health checks periodically. 
@@ -1733,10 +1600,11 @@ class Router:
                 )
             returned_item = self.weighted_shuffle_by_latency(items_with_latencies)
             return returned_item
-        elif self.routing_strategy == "usage-based-routing":
-            return self.get_usage_based_available_deployment(
-                model=model, messages=messages, input=input
-            )
+        elif (
+            self.routing_strategy == "usage-based-routing"
+            and self.lowesttpm_logger is not None
+        ):
+            return self.lowesttpm_logger.get_available_deployments(model_group=model)

         raise ValueError("No models available.")

diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py
new file mode 100644
index 000000000..2e53aae88
--- /dev/null
+++ b/litellm/router_strategy/lowest_tpm_rpm.py
@@ -0,0 +1,169 @@
+#### What this does ####
+# identifies lowest tpm deployment
+
+import dotenv, os, requests
+from typing import Optional
+from datetime import datetime
+
+dotenv.load_dotenv()  # Loading env variables using dotenv
+import traceback
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class LowestTPMLoggingHandler(CustomLogger):
+    test_flag: bool = False
+    logged_success: int = 0
+    logged_failure: int = 0
+    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
+
+    def __init__(self, router_cache: DualCache, model_list: list):
+        self.router_cache = router_cache
+        self.model_list = model_list
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            """
+            Update TPM/RPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+
+                total_tokens = response_obj["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                current_minute = datetime.now().strftime("%H-%M")
+                tpm_key = f"{model_group}:tpm:{current_minute}"
+                rpm_key = f"{model_group}:rpm:{current_minute}"
+
+                # ------------
+                # Update usage
+                # ------------
+
+                ## TPM
+                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
+                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
+
+                self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+
+                ## RPM
+                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
+                request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+                self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception as e:
+            pass
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            """
+            Update TPM/RPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+
+                total_tokens = response_obj["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                current_minute = datetime.now().strftime("%H-%M")
+                tpm_key = f"{model_group}:tpm:{current_minute}"
+                rpm_key = f"{model_group}:rpm:{current_minute}"
+
+                # ------------
+                # Update usage
+                # ------------
+                # update cache
+
+                ## TPM
+                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
+                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
+
+                self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
+
+                ## RPM
+                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
+                request_count_dict[id] = request_count_dict.get(id, 0) + 1
+
+                self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception as e:
+            pass
+
+    def get_available_deployments(self, model_group: str):
+        """
+        Returns a deployment with the lowest TPM/RPM usage.
+        """
+        # get list of potential deployments
+        current_minute = datetime.now().strftime("%H-%M")
+        tpm_key = f"{model_group}:tpm:{current_minute}"
+        rpm_key = f"{model_group}:rpm:{current_minute}"
+
+        tpm_dict = self.router_cache.get_cache(key=tpm_key)
+        rpm_dict = self.router_cache.get_cache(key=rpm_key)
+
+        # -----------------------
+        # Find lowest used model
+        # ----------------------
+        lowest_tpm = float("inf")
+        deployment = None
+
+        for item, item_tpm in tpm_dict.items():
+            ## get the item from model list
+            _deployment = None
+            for m in self.model_list:
+                if item == m["model_info"]["id"]:
+                    _deployment = m
+
+            if _deployment is None:
+                break
+            _deployment_tpm = (
+                _deployment.get("tpm", None)
+                or _deployment.get("litellm_params", {}).get("tpm", None)
+                or _deployment.get("model_info", {}).get("tpm", None)
+                or float("inf")
+            )
+
+            _deployment_rpm = (
+                _deployment.get("rpm", None)
+                or _deployment.get("litellm_params", {}).get("rpm", None)
+                or _deployment.get("model_info", {}).get("rpm", None)
+                or float("inf")
+            )
+
+            if item_tpm == 0:
+                return _deployment  # return the deployment dict, not just its id
+            elif (
+                item_tpm > _deployment_tpm or rpm_dict[item] + 1 >= _deployment_rpm
+            ):  # if user passed in tpm / rpm in the model_list
+                continue
+            elif item_tpm < lowest_tpm:
+                lowest_tpm = item_tpm
+                deployment = _deployment
+        return deployment
diff --git a/litellm/tests/test_least_busy_routing.py b/litellm/tests/test_least_busy_routing.py
index bd0855ed7..0bc125fe5 100644
--- a/litellm/tests/test_least_busy_routing.py
+++ b/litellm/tests/test_least_busy_routing.py
@@ -1,13 +1,6 @@
 #### What this tests ####
 # This tests the router's ability to identify the least busy deployment
-#
-# How is this achieved?
-# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
-# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
-# - use litellm.success + failure callbacks to log when a request completed
-# - in get_available_deployment, for a given model group name -> pick based on traffic
-
 import sys, os, asyncio, time
 import traceback
 from dotenv import load_dotenv
@@ -137,4 +130,4 @@ def test_router_get_available_deployments():
     assert return_dict[3] == 100


-test_router_get_available_deployments()
+# test_router_get_available_deployments()
diff --git a/litellm/tests/test_tpm_rpm_routing.py b/litellm/tests/test_tpm_rpm_routing.py
new file mode 100644
index 000000000..681ffc04f
--- /dev/null
+++ b/litellm/tests/test_tpm_rpm_routing.py
@@ -0,0 +1,226 @@
+#### What this tests ####
+# This tests the router's ability to pick deployment with lowest tpm
+
+import sys, os, asyncio, time
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+from litellm import Router
+import litellm
+from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
+from litellm.caching import DualCache
+
+### UNIT TESTS FOR TPM/RPM ROUTING ###
+
+
+def test_tpm_rpm_updated():
+    test_cache = DualCache()
+    model_list = []
+    lowest_tpm_logger = LowestTPMLoggingHandler(
+        router_cache=test_cache, model_list=model_list
+    )
+    model_group = "gpt-3.5-turbo"
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    end_time = time.time()
+    lowest_tpm_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    current_minute = datetime.now().strftime("%H-%M")
+    tpm_count_api_key = f"{model_group}:tpm:{current_minute}"
+    rpm_count_api_key = f"{model_group}:rpm:{current_minute}"
+    assert (
+        response_obj["usage"]["total_tokens"]
+        == test_cache.get_cache(key=tpm_count_api_key)[deployment_id]
+    )
+    assert 1 == test_cache.get_cache(key=rpm_count_api_key)[deployment_id]
+
+
+# test_tpm_rpm_updated()
+
+
+def test_get_available_deployments():
+    test_cache = DualCache()
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "azure/chatgpt-v-2"},
+            "model_info": {"id": "1234"},
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "azure/chatgpt-v-2"},
+            "model_info": {"id": "5678"},
+        },
+    ]
+    lowest_tpm_logger = LowestTPMLoggingHandler(
+        router_cache=test_cache, model_list=model_list
+    )
+    model_group = "gpt-3.5-turbo"
+    ## DEPLOYMENT 1 ##
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    end_time = time.time()
+    lowest_tpm_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    ## DEPLOYMENT 2 ##
+    deployment_id = "5678"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 20}}
+    end_time = time.time()
+    lowest_tpm_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+
+    ## CHECK WHAT'S SELECTED ##
+    print(lowest_tpm_logger.get_available_deployments(model_group=model_group))
+    assert (
+        lowest_tpm_logger.get_available_deployments(model_group=model_group)[
+            "model_info"
+        ]["id"]
+        == "5678"
+    )
+
+
+# test_get_available_deployments()
+
+
+def test_router_get_available_deployments():
+    """
+    Test if the router's 'get_available_deployment' picks the deployment with the lowest tpm usage
+    """
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_CANADA_API_KEY",
+                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 3},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="usage-based-routing",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    ## DEPLOYMENT 1 ##
+    deployment_id = 1
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "azure-model",
+            },
+            "model_info": {"id": 1},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    end_time = time.time()
+    router.lowesttpm_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    ## DEPLOYMENT 2 ##
+    deployment_id = 2
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "azure-model",
+            },
+            "model_info": {"id": 2},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 20}}
+    end_time = time.time()
+    router.lowesttpm_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+
+    ## CHECK WHAT'S SELECTED ##
+    # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
+    print(router.get_available_deployment(model="azure-model"))
+    assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2
+
+
+# test_get_available_deployments()
+
+
+# test_router_get_available_deployments()
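
Usage note (not part of the diff): a minimal sketch of how the new "usage-based-routing" strategy would be enabled from caller code, assembled only from the Router arguments exercised in test_tpm_rpm_routing.py above. The model names, API keys, and API bases are placeholders, not values from this PR.

from litellm import Router

model_list = [
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "os.environ/AZURE_EUROPE_API_KEY",
            "api_base": "https://my-endpoint-europe.openai.azure.com",
            "rpm": 6,  # optional per-deployment limit read by LowestTPMLoggingHandler
        },
        "model_info": {"id": 1},
    },
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-turbo",
            "api_key": "os.environ/AZURE_FRANCE_API_KEY",
            "api_base": "https://my-endpoint-france.openai.azure.com",
            "rpm": 1440,
        },
        "model_info": {"id": 2},
    },
]

router = Router(model_list=model_list, routing_strategy="usage-based-routing")

# Each successful call updates the per-minute "{model_group}:tpm:{HH-MM}" and
# "{model_group}:rpm:{HH-MM}" counters keyed by deployment id; the next call is
# routed to the deployment with the lowest recorded TPM that is still under its
# configured tpm/rpm limits.
response = router.completion(
    model="azure-model",
    messages=[{"role": "user", "content": "hello"}],
)
print(response)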