From f2d0d5584a53daa99ea2aefdca5ec65134895f4f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 30 Dec 2023 17:25:40 +0530
Subject: [PATCH] fix(router.py): fix latency based routing

---
 litellm/router.py                            |  30 +-
 litellm/router_strategy/lowest_latency.py    | 146 ++++++++++
 litellm/tests/test_lowest_latency_routing.py | 279 +++++++++++++++++++
 litellm/tests/test_router_init.py            |   2 +-
 4 files changed, 443 insertions(+), 14 deletions(-)
 create mode 100644 litellm/router_strategy/lowest_latency.py
 create mode 100644 litellm/tests/test_lowest_latency_routing.py

diff --git a/litellm/router.py b/litellm/router.py
index 4e953c740..ffc2bff85 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -19,6 +19,7 @@ from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
+from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
 from litellm.llms.custom_httpx.azure_dall_e_2 import (
     CustomHTTPTransport,
     AsyncCustomHTTPTransport,
@@ -211,7 +212,12 @@ class Router:
             )
             if isinstance(litellm.callbacks, list):
                 litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
-
+        elif routing_strategy == "latency-based-routing":
+            self.lowestlatency_logger = LowestLatencyLoggingHandler(
+                router_cache=self.cache, model_list=self.model_list
+            )
+            if isinstance(litellm.callbacks, list):
+                litellm.callbacks.append(self.lowestlatency_logger)  # type: ignore
         ## COOLDOWNS ##
         if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
@@ -1733,18 +1739,16 @@ class Router:
             ############## No RPM/TPM passed, we do a random pick #################
             item = random.choice(healthy_deployments)
             return item or item[0]
-        elif self.routing_strategy == "latency-based-routing":
-            returned_item = None
-            lowest_latency = float("inf")
-            ### shuffles with priority for lowest latency
-            # items_with_latencies = [('A', 10), ('B', 20), ('C', 30), ('D', 40)]
-            items_with_latencies = []
-            for item in healthy_deployments:
-                items_with_latencies.append(
-                    (item, self.deployment_latency_map[item["litellm_params"]["model"]])
-                )
-            returned_item = self.weighted_shuffle_by_latency(items_with_latencies)
-            return returned_item
+        elif (
+            self.routing_strategy == "latency-based-routing"
+            and self.lowestlatency_logger is not None
+        ):
+            min_deployment = self.lowestlatency_logger.get_available_deployments(
+                model_group=model, healthy_deployments=healthy_deployments
+            )
+            if min_deployment is None:
+                min_deployment = random.choice(healthy_deployments)
+            return min_deployment
         elif (
             self.routing_strategy == "usage-based-routing"
             and self.lowesttpm_logger is not None
diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py
new file mode 100644
index 000000000..7f1de8bbd
--- /dev/null
+++ b/litellm/router_strategy/lowest_latency.py
@@ -0,0 +1,146 @@
+#### What this does ####
+# picks based on response time (for streaming, this is time to first token)
+
+import dotenv, os, requests, random
+from typing import Optional
+from datetime import datetime
+
+dotenv.load_dotenv()  # Loading env variables using dotenv
+import traceback
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+
+
+class LowestLatencyLoggingHandler(CustomLogger):
+    test_flag: bool = False
+    logged_success: int = 0
+    logged_failure: int = 0
+    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
+
+    def __init__(self, router_cache: DualCache, model_list: list):
+        self.router_cache = router_cache
+        self.model_list = model_list
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            """
+            Update latency usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+
+                response_ms = end_time - start_time
+
+                # ------------
+                # Setup values
+                # ------------
+                latency_key = f"{model_group}_latency_map"
+
+                # ------------
+                # Update usage
+                # ------------
+
+                ## Latency
+                request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
+                request_count_dict[id] = response_ms
+
+                self.router_cache.set_cache(key=latency_key, value=request_count_dict)
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception as e:
+            traceback.print_exc()
+            pass
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            """
+            Update latency usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+
+                response_ms = end_time - start_time
+
+                # ------------
+                # Setup values
+                # ------------
+                latency_key = f"{model_group}_latency_map"
+
+                # ------------
+                # Update usage
+                # ------------
+
+                ## Latency
+                request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
+                request_count_dict[id] = response_ms
+
+                self.router_cache.set_cache(key=latency_key, value=request_count_dict)
+
+                ### TESTING ###
+                if self.test_flag:
+                    self.logged_success += 1
+        except Exception as e:
+            traceback.print_exc()
+            pass
+
+    def get_available_deployments(self, model_group: str, healthy_deployments: list):
+        """
+        Returns a deployment with the lowest latency
+        """
+        # get list of potential deployments
+        latency_key = f"{model_group}_latency_map"
+
+        request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
+
+        # -----------------------
+        # Find lowest used model
+        # ----------------------
+        lowest_latency = float("inf")
+        deployment = None
+
+        if request_count_dict is None:  # base case
+            return
+
+        all_deployments = request_count_dict
+        for d in healthy_deployments:
+            ## if healthy deployment not yet used
+            if d["model_info"]["id"] not in all_deployments:
+                all_deployments[d["model_info"]["id"]] = 0
+
+        for item, item_latency in all_deployments.items():
+            ## get the item from model list
+            _deployment = None
+            for m in healthy_deployments:
+                if item == m["model_info"]["id"]:
+                    _deployment = m
+
+            if _deployment is None:
+                continue  # skip to next one
+
+            if item_latency == 0:
+                deployment = _deployment
+                break
+            elif item_latency < lowest_latency:
+                lowest_latency = item_latency
+                deployment = _deployment
+        if deployment is None:
+            deployment = random.choice(healthy_deployments)
+        return deployment
diff --git a/litellm/tests/test_lowest_latency_routing.py b/litellm/tests/test_lowest_latency_routing.py
new file mode 100644
index 000000000..a1a501a34
--- /dev/null
+++ b/litellm/tests/test_lowest_latency_routing.py
@@ -0,0 +1,279 @@
+#### What this tests ####
+# This tests the router's ability to pick deployment with lowest latency
+
+import sys, os, asyncio, time, random
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+from litellm import Router
+import litellm
+from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
+from litellm.caching import DualCache
+
+### UNIT TESTS FOR LATENCY ROUTING ###
+
+
+def test_latency_updated():
+    test_cache = DualCache()
+    model_list = []
+    lowest_latency_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list
+    )
+    model_group = "gpt-3.5-turbo"
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    time.sleep(5)
+    end_time = time.time()
+    lowest_latency_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    latency_key = f"{model_group}_latency_map"
+    assert end_time - start_time == test_cache.get_cache(key=latency_key)[deployment_id]
+
+
+# test_tpm_rpm_updated()
+
+
+def test_get_available_deployments():
+    test_cache = DualCache()
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "azure/chatgpt-v-2"},
+            "model_info": {"id": "1234"},
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "azure/chatgpt-v-2"},
+            "model_info": {"id": "5678"},
+        },
+    ]
+    lowest_tpm_logger = LowestLatencyLoggingHandler(
+        router_cache=test_cache, model_list=model_list
+    )
+    model_group = "gpt-3.5-turbo"
+    ## DEPLOYMENT 1 ##
+    deployment_id = "1234"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 50}}
+    end_time = time.time()
+    lowest_tpm_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+    ## DEPLOYMENT 2 ##
+    deployment_id = "5678"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": deployment_id},
+        }
+    }
+    start_time = time.time()
+    response_obj = {"usage": {"total_tokens": 20}}
+    end_time = time.time()
+    lowest_tpm_logger.log_success_event(
+        response_obj=response_obj,
+        kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
+    )
+
+    ## CHECK WHAT'S SELECTED ##
+    print(
+        lowest_tpm_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )
+    )
+    assert (
+        lowest_tpm_logger.get_available_deployments(
+            model_group=model_group, healthy_deployments=model_list
+        )["model_info"]["id"]
+        == "5678"
+    )
+
+
+# test_get_available_deployments()
+
+
+# def test_router_get_available_deployments():
+#     """
+#     Test if routers 'get_available_deployments' returns the least busy deployment
+#     """
+#     model_list = [
+#         {
+#             "model_name": "azure-model",
+#             "litellm_params": {
+#                 "model": "azure/gpt-turbo",
+#                 "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+#                 "api_base": "https://openai-france-1234.openai.azure.com",
+#                 "rpm": 1440,
+#             },
+#             "model_info": {"id": 1},
+#         },
+#         {
+#             "model_name": "azure-model",
+#             "litellm_params": {
+#                 "model": "azure/gpt-35-turbo",
+#                 "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+#                 "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+#                 "rpm": 6,
+#             },
+#             "model_info": {"id": 2},
+#         },
+#     ]
+#     router = Router(
+#         model_list=model_list,
+#         routing_strategy="usage-based-routing",
+#         set_verbose=False,
+#         num_retries=3,
+#     )  # type: ignore
+
+#     ## DEPLOYMENT 1 ##
+#     deployment_id = 1
+#     kwargs = {
+#         "litellm_params": {
+#             "metadata": {
+#                 "model_group": "azure-model",
+#             },
+#             "model_info": {"id": 1},
+#         }
+#     }
+#     start_time = time.time()
+#     response_obj = {"usage": {"total_tokens": 50}}
+#     end_time = time.time()
+#     router.lowesttpm_logger.log_success_event(
+#         response_obj=response_obj,
+#         kwargs=kwargs,
+#         start_time=start_time,
+#         end_time=end_time,
+#     )
+#     ## DEPLOYMENT 2 ##
+#     deployment_id = 2
+#     kwargs = {
+#         "litellm_params": {
+#             "metadata": {
+#                 "model_group": "azure-model",
+#             },
+#             "model_info": {"id": 2},
+#         }
+#     }
+#     start_time = time.time()
+#     response_obj = {"usage": {"total_tokens": 20}}
+#     end_time = time.time()
+#     router.lowesttpm_logger.log_success_event(
+#         response_obj=response_obj,
+#         kwargs=kwargs,
+#         start_time=start_time,
+#         end_time=end_time,
+#     )
+
+#     ## CHECK WHAT'S SELECTED ##
+#     # print(router.lowesttpm_logger.get_available_deployments(model_group="azure-model"))
+#     print(router.get_available_deployment(model="azure-model"))
+#     assert router.get_available_deployment(model="azure-model")["model_info"]["id"] == 2
+
+
+# # test_get_available_deployments()
+
+
+# # test_router_get_available_deployments()
+
+
+# @pytest.mark.asyncio
+# async def test_router_completion_streaming():
+#     messages = [
+#         {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
+#     ]
+#     model = "azure-model"
+#     model_list = [
+#         {
+#             "model_name": "azure-model",
+#             "litellm_params": {
+#                 "model": "azure/gpt-turbo",
+#                 "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+#                 "api_base": "https://openai-france-1234.openai.azure.com",
+#                 "rpm": 1440,
+#             },
+#             "model_info": {"id": 1},
+#         },
+#         {
+#             "model_name": "azure-model",
+#             "litellm_params": {
+#                 "model": "azure/gpt-35-turbo",
+#                 "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+#                 "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+#                 "rpm": 6,
+#             },
+#             "model_info": {"id": 2},
+#         },
+#     ]
+#     router = Router(
+#         model_list=model_list,
+#         routing_strategy="usage-based-routing",
+#         set_verbose=False,
+#         num_retries=3,
+#     )  # type: ignore
+
+#     ### Make 3 calls, test if 3rd call goes to lowest tpm deployment
+
+#     ## CALL 1+2
+#     tasks = []
+#     response = None
+#     final_response = None
+#     for _ in range(2):
+#         tasks.append(router.acompletion(model=model, messages=messages))
+#     response = await asyncio.gather(*tasks)
+
+#     if response is not None:
+#         ## CALL 3
+#         await asyncio.sleep(1)  # let the token update happen
+#         current_minute = datetime.now().strftime("%H-%M")
+#         picked_deployment = router.lowesttpm_logger.get_available_deployments(
+#             model_group=model, healthy_deployments=router.healthy_deployments
+#         )
+#         final_response = await router.acompletion(model=model, messages=messages)
+#         print(f"min deployment id: {picked_deployment}")
+#         print(f"model id: {final_response._hidden_params['model_id']}")
+#         assert (
+#             final_response._hidden_params["model_id"]
+#             == picked_deployment["model_info"]["id"]
+#         )
+
+
+# # asyncio.run(test_router_completion_streaming())
diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py
index 581506d13..bd70a33d1 100644
--- a/litellm/tests/test_router_init.py
+++ b/litellm/tests/test_router_init.py
@@ -199,4 +199,4 @@ def test_stream_timeouts_router():
     )


-test_stream_timeouts_router()
+# test_stream_timeouts_router()
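
Below is a minimal usage sketch (not part of the patch) of how the new handler behaves. It assumes only what the patch itself introduces: LowestLatencyLoggingHandler, DualCache, and the model_list / kwargs shapes used in the new unit tests. Deployment ids and sleep durations are illustrative stand-ins. Inside the Router, the same logic is enabled by constructing it with routing_strategy="latency-based-routing".

# Hypothetical driver script: records one slow and one fast response,
# then asks the handler for the lowest-latency deployment.
import time

from litellm.caching import DualCache
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "azure/chatgpt-v-2"},
        "model_info": {"id": "1234"},
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "azure/chatgpt-v-2"},
        "model_info": {"id": "5678"},
    },
]
handler = LowestLatencyLoggingHandler(router_cache=DualCache(), model_list=model_list)

# Log a slow response for "1234" and a fast one for "5678".
for deployment_id, delay in [("1234", 0.2), ("5678", 0.05)]:
    kwargs = {
        "litellm_params": {
            "metadata": {"model_group": "gpt-3.5-turbo"},
            "model_info": {"id": deployment_id},
        }
    }
    start_time = time.time()
    time.sleep(delay)  # stand-in for the actual LLM call
    handler.log_success_event(
        kwargs=kwargs,
        response_obj={},
        start_time=start_time,
        end_time=time.time(),
    )

# Returns the deployment with the lowest recorded latency -> "5678".
best = handler.get_available_deployments(
    model_group="gpt-3.5-turbo", healthy_deployments=model_list
)
print(best["model_info"]["id"])

One design note: a healthy deployment with no recorded latency is assigned a latency of 0 in get_available_deployments and selected immediately, so brand-new deployments receive traffic before any measured ones.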