From f288b12411b7994fb331390b729cf9c46dd73a07 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 10 Jan 2024 20:52:01 +0530 Subject: [PATCH] fix(lowest_latency.py): add back tpm/rpm checks, configurable time window --- litellm/router.py | 25 ++- litellm/router_strategy/lowest_latency.py | 205 ++++++++++++++++--- litellm/tests/test_lowest_latency_routing.py | 133 +++++++++++- 3 files changed, 312 insertions(+), 51 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 311afeb446..f635555098 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -105,6 +105,7 @@ class Router: "usage-based-routing", "latency-based-routing", ] = "simple-shuffle", + routing_strategy_args: dict = {}, # just for latency-based routing ) -> None: self.set_verbose = set_verbose self.deployment_names: List = ( @@ -217,7 +218,9 @@ class Router: litellm.callbacks.append(self.lowesttpm_logger) # type: ignore elif routing_strategy == "latency-based-routing": self.lowestlatency_logger = LowestLatencyLoggingHandler( - router_cache=self.cache, model_list=self.model_list + router_cache=self.cache, + model_list=self.model_list, + routing_args=routing_strategy_args, ) if isinstance(litellm.callbacks, list): litellm.callbacks.append(self.lowestlatency_logger) # type: ignore @@ -1427,9 +1430,8 @@ class Router: http_client=httpx.AsyncClient( transport=AsyncCustomHTTPTransport(), limits=httpx.Limits( - max_connections=1000, - max_keepalive_connections=100 - ) + max_connections=1000, max_keepalive_connections=100 + ), ), # type: ignore ) self.cache.set_cache( @@ -1449,9 +1451,8 @@ class Router: http_client=httpx.Client( transport=CustomHTTPTransport(), limits=httpx.Limits( - max_connections=1000, - max_keepalive_connections=100 - ) + max_connections=1000, max_keepalive_connections=100 + ), ), # type: ignore ) self.cache.set_cache( @@ -1471,10 +1472,9 @@ class Router: max_retries=max_retries, http_client=httpx.AsyncClient( limits=httpx.Limits( - max_connections=1000, - max_keepalive_connections=100 + max_connections=1000, max_keepalive_connections=100 ) - ) + ), ) self.cache.set_cache( key=cache_key, @@ -1492,10 +1492,9 @@ class Router: max_retries=max_retries, http_client=httpx.Client( limits=httpx.Limits( - max_connections=1000, - max_keepalive_connections=100 + max_connections=1000, max_keepalive_connections=100 ) - ) + ), ) self.cache.set_cache( key=cache_key, diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index 43e28a8b3a..3f8cb513b4 100644 --- a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -1,25 +1,46 @@ #### What this does #### # picks based on response time (for streaming, this is time to first token) - +from pydantic import BaseModel, Extra, Field, root_validator import dotenv, os, requests, random -from typing import Optional +from typing import Optional, Union, List, Dict from datetime import datetime, timedelta dotenv.load_dotenv() # Loading env variables using dotenv import traceback from litellm.caching import DualCache from litellm.integrations.custom_logger import CustomLogger +from litellm import ModelResponse +from litellm import token_counter + + +class LiteLLMBase(BaseModel): + """ + Implements default functions, all pydantic objects should have. 
+ """ + + def json(self, **kwargs): + try: + return self.model_dump() # noqa + except: + # if using pydantic v1 + return self.dict() + + +class RoutingArgs(LiteLLMBase): + ttl: int = 1 * 60 * 60 # 1 hour class LowestLatencyLoggingHandler(CustomLogger): test_flag: bool = False logged_success: int = 0 logged_failure: int = 0 - default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour - def __init__(self, router_cache: DualCache, model_list: list): + def __init__( + self, router_cache: DualCache, model_list: list, routing_args: dict = {} + ): self.router_cache = router_cache self.model_list = model_list + self.routing_args = RoutingArgs(**routing_args) def log_success_event(self, kwargs, response_obj, start_time, end_time): try: @@ -37,25 +58,64 @@ class LowestLatencyLoggingHandler(CustomLogger): if model_group is None or id is None: return - response_ms = end_time - start_time - # ------------ # Setup values # ------------ - latency_key = f"{model_group}_latency_map" + """ + { + {model_group}_map: { + id: { + "latency": [..] + f"{date:hour:minute}" : {"tpm": 34, "rpm": 3} + } + } + } + """ + latency_key = f"{model_group}_map" + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + response_ms: timedelta = end_time - start_time + + final_value = response_ms + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + completion_tokens = response_obj.usage.completion_tokens + total_tokens = response_obj.usage.total_tokens + final_value = float(completion_tokens / response_ms.total_seconds()) # ------------ # Update usage # ------------ - ## Latency request_count_dict = self.router_cache.get_cache(key=latency_key) or {} - if id in request_count_dict and isinstance(request_count_dict[id], list): - request_count_dict[id] = request_count_dict[id].append(response_ms) - else: - request_count_dict[id] = [response_ms] - self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window + if id not in request_count_dict: + request_count_dict[id] = {} + + ## Latency + request_count_dict[id].setdefault("latency", []).append(final_value) + + if precise_minute not in request_count_dict[id]: + request_count_dict[id][precise_minute] = {} + + ## TPM + request_count_dict[id][precise_minute]["tpm"] = ( + request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens + ) + + ## RPM + request_count_dict[id][precise_minute]["rpm"] = ( + request_count_dict[id][precise_minute].get("rpm", 0) + 1 + ) + + self.router_cache.set_cache( + key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl + ) # reset map within window ### TESTING ### if self.test_flag: @@ -80,25 +140,64 @@ class LowestLatencyLoggingHandler(CustomLogger): if model_group is None or id is None: return - response_ms = end_time - start_time - # ------------ # Setup values # ------------ - latency_key = f"{model_group}_latency_map" + """ + { + {model_group}_map: { + id: { + "latency": [..] 
+ f"{date:hour:minute}" : {"tpm": 34, "rpm": 3} + } + } + } + """ + latency_key = f"{model_group}_map" + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + + response_ms: timedelta = end_time - start_time + + final_value = response_ms + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + completion_tokens = response_obj.usage.completion_tokens + total_tokens = response_obj.usage.total_tokens + final_value = float(completion_tokens / response_ms.total_seconds()) # ------------ # Update usage # ------------ - ## Latency request_count_dict = self.router_cache.get_cache(key=latency_key) or {} - if id in request_count_dict and isinstance(request_count_dict[id], list): - request_count_dict[id] = request_count_dict[id] + [response_ms] - else: - request_count_dict[id] = [response_ms] - self.router_cache.set_cache(key=latency_key, value=request_count_dict, ttl=self.default_cache_time_seconds) # reset map within window + if id not in request_count_dict: + request_count_dict[id] = {} + + ## Latency + request_count_dict[id].setdefault("latency", []).append(final_value) + + if precise_minute not in request_count_dict[id]: + request_count_dict[id][precise_minute] = {} + + ## TPM + request_count_dict[id][precise_minute]["tpm"] = ( + request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens + ) + + ## RPM + request_count_dict[id][precise_minute]["rpm"] = ( + request_count_dict[id][precise_minute].get("rpm", 0) + 1 + ) + + self.router_cache.set_cache( + key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl + ) # reset map within window ### TESTING ### if self.test_flag: @@ -107,12 +206,18 @@ class LowestLatencyLoggingHandler(CustomLogger): traceback.print_exc() pass - def get_available_deployments(self, model_group: str, healthy_deployments: list): + def get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + ): """ Returns a deployment with the lowest latency """ # get list of potential deployments - latency_key = f"{model_group}_latency_map" + latency_key = f"{model_group}_map" request_count_dict = self.router_cache.get_cache(key=latency_key) or {} @@ -120,6 +225,12 @@ class LowestLatencyLoggingHandler(CustomLogger): # Find lowest used model # ---------------------- lowest_latency = float("inf") + + current_date = datetime.now().strftime("%Y-%m-%d") + current_hour = datetime.now().strftime("%H") + current_minute = datetime.now().strftime("%M") + precise_minute = f"{current_date}-{current_hour}-{current_minute}" + deployment = None if request_count_dict is None: # base case @@ -129,9 +240,17 @@ class LowestLatencyLoggingHandler(CustomLogger): for d in healthy_deployments: ## if healthy deployment not yet used if d["model_info"]["id"] not in all_deployments: - all_deployments[d["model_info"]["id"]] = [0] + all_deployments[d["model_info"]["id"]] = { + "latency": [0], + precise_minute: {"tpm": 0, "rpm": 0}, + } - for item, item_latency in all_deployments.items(): + try: + input_tokens = token_counter(messages=messages, text=input) + except: + input_tokens = 0 + + for item, item_map in all_deployments.items(): ## get the item from model list _deployment = None for m in healthy_deployments: @@ -140,18 +259,38 @@ class LowestLatencyLoggingHandler(CustomLogger): if _deployment is 
None: continue # skip to next one - - # get average latency - total = 0.0 + + _deployment_tpm = ( + _deployment.get("tpm", None) + or _deployment.get("litellm_params", {}).get("tpm", None) + or _deployment.get("model_info", {}).get("tpm", None) + or float("inf") + ) + + _deployment_rpm = ( + _deployment.get("rpm", None) + or _deployment.get("litellm_params", {}).get("rpm", None) + or _deployment.get("model_info", {}).get("rpm", None) + or float("inf") + ) + item_latency = item_map.get("latency", []) + item_rpm = item_map.get(precise_minute, {}).get("rpm", 0) + item_tpm = item_map.get(precise_minute, {}).get("tpm", 0) + + # get average latency + total: float = 0.0 for _call_latency in item_latency: - if isinstance(_call_latency, timedelta): - total += float(_call_latency.total_seconds()) - elif isinstance(_call_latency, float): + if isinstance(_call_latency, float): total += _call_latency - item_latency = total/len(item_latency) + item_latency = total / len(item_latency) if item_latency == 0: deployment = _deployment break + elif ( + item_tpm + input_tokens > _deployment_tpm + or item_rpm + 1 > _deployment_rpm + ): # if user passed in tpm / rpm in the model_list + continue elif item_latency < lowest_latency: lowest_latency = item_latency deployment = _deployment diff --git a/litellm/tests/test_lowest_latency_routing.py b/litellm/tests/test_lowest_latency_routing.py index 6805bfa588..c9b1e7972c 100644 --- a/litellm/tests/test_lowest_latency_routing.py +++ b/litellm/tests/test_lowest_latency_routing.py @@ -48,13 +48,55 @@ def test_latency_updated(): start_time=start_time, end_time=end_time, ) - latency_key = f"{model_group}_latency_map" - assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id][0] + latency_key = f"{model_group}_map" + assert ( + end_time - start_time + == test_cache.get_cache(key=latency_key)[deployment_id]["latency"][0] + ) # test_tpm_rpm_updated() +def test_latency_updated_custom_ttl(): + """ + Invalidate the cached request. 
+ + Test that the cache is empty + """ + test_cache = DualCache() + model_list = [] + cache_time = 3 + lowest_latency_logger = LowestLatencyLoggingHandler( + router_cache=test_cache, model_list=model_list, routing_args={"ttl": cache_time} + ) + model_group = "gpt-3.5-turbo" + deployment_id = "1234" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + time.sleep(5) + end_time = time.time() + lowest_latency_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + latency_key = f"{model_group}_map" + assert isinstance(test_cache.get_cache(key=latency_key), dict) + time.sleep(cache_time) + assert test_cache.get_cache(key=latency_key) is None + + def test_get_available_deployments(): test_cache = DualCache() model_list = [ @@ -133,6 +175,90 @@ def test_get_available_deployments(): # test_get_available_deployments() +def test_get_available_endpoints_tpm_rpm_check(): + """ + Pass in list of 2 valid models + + Update cache with 1 model clearly being at tpm/rpm limit + + assert that only the valid model is returned + """ + test_cache = DualCache() + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "azure/chatgpt-v-2"}, + "model_info": {"id": "1234", "rpm": 10}, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "azure/chatgpt-v-2"}, + "model_info": {"id": "5678", "rpm": 3}, + }, + ] + lowest_latency_logger = LowestLatencyLoggingHandler( + router_cache=test_cache, model_list=model_list + ) + model_group = "gpt-3.5-turbo" + ## DEPLOYMENT 1 ## + deployment_id = "1234" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + for _ in range(3): + start_time = time.time() + response_obj = {"usage": {"total_tokens": 50}} + time.sleep(0.05) + end_time = time.time() + lowest_latency_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + ## DEPLOYMENT 2 ## + deployment_id = "5678" + kwargs = { + "litellm_params": { + "metadata": { + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + }, + "model_info": {"id": deployment_id}, + } + } + for _ in range(3): + start_time = time.time() + response_obj = {"usage": {"total_tokens": 20}} + time.sleep(2) + end_time = time.time() + lowest_latency_logger.log_success_event( + response_obj=response_obj, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + ) + + ## CHECK WHAT'S SELECTED ## + print( + lowest_latency_logger.get_available_deployments( + model_group=model_group, healthy_deployments=model_list + ) + ) + assert ( + lowest_latency_logger.get_available_deployments( + model_group=model_group, healthy_deployments=model_list + )["model_info"]["id"] + == "1234" + ) + + def test_router_get_available_deployments(): """ Test if routers 'get_available_deployments' returns the fastest deployment @@ -213,9 +339,6 @@ def test_router_get_available_deployments(): assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2 -# test_get_available_deployments() - - # test_router_get_available_deployments()
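
Usage sketch (not part of the patch): how the new routing_strategy_args window and the restored per-deployment tpm/rpm limits fit together at the Router level. The deployment names, environment variables, and the 120-second ttl below are illustrative placeholders, not values taken from this diff.

    import os
    from litellm import Router

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",            # placeholder deployment
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
            },
            "model_info": {"id": "1234", "rpm": 10},     # rpm/tpm here are read by lowest_latency.py
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
            },
            "model_info": {"id": "5678", "rpm": 3},
        },
    ]

    router = Router(
        model_list=model_list,
        routing_strategy="latency-based-routing",
        routing_strategy_args={"ttl": 120},  # keep the latency/tpm/rpm map for 120s (default: 1 hour)
    )

    # deployments already at their rpm/tpm limit for the current minute are skipped;
    # among the remaining ones, the lowest average latency wins
    print(router.get_available_deployment(model="gpt-3.5-turbo")["model_info"]["id"])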
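
For reference, the value lowest_latency.py now caches under the f"{model_group}_map" key has the shape described in the docstrings above; a filled-in example, with made-up numbers:

    # shape of router_cache.get_cache(key="gpt-3.5-turbo_map") after a few logged requests
    expected_map = {
        "1234": {                                        # deployment id
            "latency": [0.32, 0.41],                     # per-request values appended by log_success_event
            "2024-01-10-20-52": {"tpm": 34, "rpm": 3},   # usage bucketed per minute (YYYY-MM-DD-HH-MM)
        }
    }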