fix(least_busy.py): support consistent use of model id instead of deployment name

Krrish Dholakia 2023-12-29 17:05:14 +05:30
parent 06e4b301b4
commit 678bbfa9be
3 changed files with 144 additions and 98 deletions
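In short: the least-busy strategy used to track in-flight requests in a cache dict keyed by deployment name, with a separate deployment-to-id mapping, and this commit keys that dict directly by the deployment's model_info["id"] everywhere. A minimal, self-contained sketch of the new bookkeeping (the log_request_start helper and the plain-dict cache are illustrative stand-ins, not litellm APIs):

# Sketch only: counts in-flight requests per model id under a
# "<model_group>_request_count" key, mirroring the handler changes below.
cache: dict = {}

def log_request_start(kwargs: dict) -> None:
    metadata = kwargs["litellm_params"].get("metadata") or {}
    model_group = metadata.get("model_group")
    deployment_id = kwargs["litellm_params"].get("model_info", {}).get("id")
    if model_group is None or deployment_id is None:
        return
    key = f"{model_group}_request_count"
    counts = cache.get(key) or {}
    counts[deployment_id] = counts.get(deployment_id, 0) + 1
    cache[key] = counts

log_request_start(
    {
        "litellm_params": {
            "metadata": {"model_group": "azure-model"},
            "model_info": {"id": "1234"},
        }
    }
)
print(cache)  # {'azure-model_request_count': {'1234': 1}}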

litellm/router.py

@@ -287,7 +287,7 @@ class Router:
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
# response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
response = await self.async_function_with_fallbacks(**kwargs)
return response
@@ -1664,6 +1664,7 @@ class Router:
deployments = self.leastbusy_logger.get_available_deployments(
model_group=model
)
self.print_verbose(f"deployments in least-busy router: {deployments}")
# pick least busy deployment
min_traffic = float("inf")
min_deployment = None
@@ -1671,14 +1672,19 @@ class Router:
if v < min_traffic:
min_traffic = v
min_deployment = k
self.print_verbose(f"min_deployment: {min_deployment};")
############## No Available Deployments passed, we do a random pick #################
if min_deployment is None:
min_deployment = random.choice(healthy_deployments)
############## Available Deployments passed, we find the relevant item #################
else:
## check if min deployment is a string, if so, cast it to int
if isinstance(min_deployment, str):
min_deployment = int(min_deployment)
for m in healthy_deployments:
if m["model_info"]["id"] == min_deployment:
return m
self.print_verbose(f"no healthy deployment with that id found!")
min_deployment = random.choice(healthy_deployments)
return min_deployment
elif self.routing_strategy == "simple-shuffle":
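For context, the least-busy branch above boils down to: take the per-id traffic counts from the logger, pick the id with the fewest in-flight requests, and return the matching entry from healthy_deployments, falling back to a random healthy deployment when there is no data or no match. A standalone sketch under those assumptions (pick_least_busy and its inputs are hypothetical; the real logic lives in Router.get_available_deployment):

import random

def pick_least_busy(traffic_by_id: dict, healthy_deployments: list) -> dict:
    # find the model id with the fewest in-flight requests
    min_traffic = float("inf")
    min_id = None
    for deployment_id, in_flight in traffic_by_id.items():
        if in_flight < min_traffic:
            min_traffic = in_flight
            min_id = deployment_id
    if min_id is None:
        # no traffic data yet -> random pick
        return random.choice(healthy_deployments)
    if isinstance(min_id, str):
        # the cache may round-trip ids as strings; model_info ids are ints here
        min_id = int(min_id)
    for deployment in healthy_deployments:
        if deployment["model_info"]["id"] == min_id:
            return deployment
    # least-busy id is not among the healthy deployments -> random pick
    return random.choice(healthy_deployments)

# e.g. pick_least_busy({1: 10, 2: 54, 3: 100}, deployments) returns the
# deployment whose model_info["id"] == 1, matching the unit test further down.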

litellm/router_strategy/least_busy.py

@@ -30,27 +30,20 @@ class LeastBusyLoggingHandler(CustomLogger):
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if deployment is None or model_group is None or id is None:
if model_group is None or id is None:
return
# map deployment to id
self.mapping_deployment_to_id[deployment] = id
request_count_api_key = f"{model_group}_request_count"
# update cache
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = (
request_count_dict.get(deployment, 0) + 1
)
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
@@ -62,13 +55,12 @@ class LeastBusyLoggingHandler(CustomLogger):
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None:
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
request_count_api_key = f"{model_group}_request_count"
@@ -76,7 +68,7 @@ class LeastBusyLoggingHandler(CustomLogger):
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment)
request_count_dict[id] = request_count_dict.get(id) - 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
@@ -88,13 +80,11 @@ class LeastBusyLoggingHandler(CustomLogger):
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None:
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
request_count_api_key = f"{model_group}_request_count"
@@ -102,7 +92,7 @@ class LeastBusyLoggingHandler(CustomLogger):
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment)
request_count_dict[id] = request_count_dict.get(id) - 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
@@ -111,11 +101,6 @@ class LeastBusyLoggingHandler(CustomLogger):
def get_available_deployments(self, model_group: str):
request_count_api_key = f"{model_group}_request_count"
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
return_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
# map deployment to id
return_dict = {}
for key, value in request_count_dict.items():
return_dict[self.mapping_deployment_to_id[key]] = value
return return_dict
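Taken together, the handler keeps a simple in-flight counter per model id: log_pre_api_call increments it, the success and failure callbacks decrement it, and get_available_deployments returns the current counts for a model group. A short usage sketch in the spirit of the tests below (the kwargs values are made up):

from litellm.caching import DualCache
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler

cache = DualCache()
handler = LeastBusyLoggingHandler(router_cache=cache)

kwargs = {
    "litellm_params": {
        "metadata": {"model_group": "azure-model"},
        "model_info": {"id": "1"},
    }
}

# a request is about to go out -> the count for id "1" becomes 1
handler.log_pre_api_call(model="azure-model", messages=[], kwargs=kwargs)
print(cache.get_cache(key="azure-model_request_count"))  # expected: {'1': 1}

# current in-flight counts for the group, keyed by model id
print(handler.get_available_deployments(model_group="azure-model"))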

litellm/tests/test_least_busy_routing.py

@@ -1,79 +1,134 @@
# #### What this tests ####
# # This tests the router's ability to identify the least busy deployment
#### What this tests ####
# This tests the router's ability to identify the least busy deployment
# #
# # How is this achieved?
# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# # - use litellm.success + failure callbacks to log when a request completed
# # - in get_available_deployment, for a given model group name -> pick based on traffic
#
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic
# import sys, os, asyncio, time
# import traceback
# from dotenv import load_dotenv
import sys, os, asyncio, time
import traceback
from dotenv import load_dotenv
# load_dotenv()
# import os
load_dotenv()
import os
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# from litellm import Router
# import litellm
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import Router
import litellm
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.caching import DualCache
# async def test_least_busy_routing():
# model_list = [{
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-turbo",
# "api_key": "os.environ/AZURE_FRANCE_API_KEY",
# "api_base": "https://openai-france-1234.openai.azure.com",
# "rpm": 1440,
# }
# }, {
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-35-turbo",
# "api_key": "os.environ/AZURE_EUROPE_API_KEY",
# "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
# "rpm": 6
# }
# }, {
# "model_name": "azure-model",
# "litellm_params": {
# "model": "azure/gpt-35-turbo",
# "api_key": "os.environ/AZURE_CANADA_API_KEY",
# "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
# "rpm": 6
# }
# }]
# router = Router(model_list=model_list,
# routing_strategy="least-busy",
# set_verbose=False,
# num_retries=3) # type: ignore
### UNIT TESTS FOR LEAST BUSY LOGGING ###
# async def call_azure_completion():
# try:
# response = await router.acompletion(
# model="azure-model",
# messages=[
# {
# "role": "user",
# "content": "hello this request will pass"
# }
# ]
# )
# print("\n response", response)
# return response
# except:
# return None
# n = 1000
# start_time = time.time()
# tasks = [call_azure_completion() for _ in range(n)]
# chat_completions = await asyncio.gather(*tasks)
# successful_completions = [c for c in chat_completions if c is not None]
# print(n, time.time() - start_time, len(successful_completions))
def test_model_added():
test_cache = DualCache()
least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
kwargs = {
"litellm_params": {
"metadata": {
"model_group": "gpt-3.5-turbo",
"deployment": "azure/chatgpt-v-2",
},
"model_info": {"id": "1234"},
}
}
least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
request_count_api_key = f"gpt-3.5-turbo_request_count"
assert test_cache.get_cache(key=request_count_api_key) is not None
# asyncio.run(test_least_busy_routing())
def test_get_available_deployments():
test_cache = DualCache()
least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
model_group = "gpt-3.5-turbo"
deployment = "azure/chatgpt-v-2"
kwargs = {
"litellm_params": {
"metadata": {
"model_group": model_group,
"deployment": deployment,
},
"model_info": {"id": "1234"},
}
}
least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
request_count_api_key = f"{model_group}_request_count"
assert test_cache.get_cache(key=request_count_api_key) is not None
print(least_busy_logger.get_available_deployments(model_group=model_group))
# test_get_available_deployments()
def test_router_get_available_deployments():
"""
Tests if 'get_available_deployments' returns the least busy deployment
"""
model_list = [
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-turbo",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 1440,
},
"model_info": {"id": 1},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
},
{
"model_name": "azure-model",
"litellm_params": {
"model": "azure/gpt-35-turbo",
"api_key": "os.environ/AZURE_CANADA_API_KEY",
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 3},
},
]
router = Router(
model_list=model_list,
routing_strategy="least-busy",
set_verbose=False,
num_retries=3,
) # type: ignore
model_group = "azure-model"
deployment = "azure/chatgpt-v-2"
request_count_dict = {1: 10, 2: 54, 3: 100}
cache_key = f"{model_group}_request_count"
router.cache.set_cache(key=cache_key, value=request_count_dict)
deployment = router.get_available_deployment(model=model_group, messages=None)
print(f"deployment: {deployment}")
assert deployment["model_info"]["id"] == 1
## run router completion - assert that the least-busy deployment was incremented
router.completion(
model=model_group,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
least_busy_dict = router.cache.get_cache(key=cache_key)
assert least_busy_dict[1] == 11
# test_router_get_available_deployments()