forked from phoenix/litellm-mirror
fix(least_busy.py): support consistent use of model id instead of deployment name
This commit is contained in:
parent 06e4b301b4
commit 678bbfa9be
3 changed files with 144 additions and 98 deletions
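In short: the least-busy logger and the router now key the per-model-group request counts by each deployment's model_info["id"] instead of by its deployment name, so the counts the logger writes and the ids the router compares always line up, and the old deployment-to-id mapping can be dropped. A minimal sketch of the idea (hypothetical names, not litellm's actual classes or signatures):

# Rough sketch only: in-flight request counts keyed by model id, per model group.
request_counts = {}  # {model_group: {model_id: requests currently in flight}}

def log_pre_api_call(model_group, model_id):
    counts = request_counts.setdefault(model_group, {})
    counts[model_id] = counts.get(model_id, 0) + 1  # increment when a call is dispatched

def log_call_finished(model_group, model_id):
    counts = request_counts.setdefault(model_group, {})
    counts[model_id] = counts.get(model_id, 0) - 1  # decrement on success or failure

def least_busy_id(model_group):
    counts = request_counts.get(model_group, {})
    return min(counts, key=counts.get) if counts else None  # id with fewest in-flight calls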
litellm/router.py

@@ -287,7 +287,7 @@ class Router:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            # response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
             response = await self.async_function_with_fallbacks(**kwargs)

             return response

@@ -1664,6 +1664,7 @@ class Router:
             deployments = self.leastbusy_logger.get_available_deployments(
                 model_group=model
             )
+            self.print_verbose(f"deployments in least-busy router: {deployments}")
             # pick least busy deployment
             min_traffic = float("inf")
             min_deployment = None

@@ -1671,14 +1672,19 @@ class Router:
                 if v < min_traffic:
                     min_traffic = v
                     min_deployment = k
+            self.print_verbose(f"min_deployment: {min_deployment};")
             ############## No Available Deployments passed, we do a random pick #################
             if min_deployment is None:
                 min_deployment = random.choice(healthy_deployments)
             ############## Available Deployments passed, we find the relevant item #################
             else:
+                ## check if min deployment is a string, if so, cast it to int
+                if isinstance(min_deployment, str):
+                    min_deployment = int(min_deployment)
                 for m in healthy_deployments:
                     if m["model_info"]["id"] == min_deployment:
                         return m
+                self.print_verbose(f"no healthy deployment with that id found!")
                 min_deployment = random.choice(healthy_deployments)
             return min_deployment
         elif self.routing_strategy == "simple-shuffle":
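The selection branch above compares the cached least-busy id against each deployment's model_info["id"], casting string ids to int first. A small self-contained illustration of that matching step, using made-up deployment data rather than the real config:

healthy_deployments = [
    {"model_info": {"id": 1}, "litellm_params": {"model": "azure/gpt-turbo"}},
    {"model_info": {"id": 2}, "litellm_params": {"model": "azure/gpt-35-turbo"}},
]
min_deployment = "2"  # ids read back from the cache may arrive as strings
if isinstance(min_deployment, str):
    min_deployment = int(min_deployment)
# find the healthy deployment whose model id matches the least-busy id
match = next(
    (m for m in healthy_deployments if m["model_info"]["id"] == min_deployment), None
)
assert match is not None and match["model_info"]["id"] == 2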
litellm/router_strategy/least_busy.py

@@ -30,27 +30,20 @@ class LeastBusyLoggingHandler(CustomLogger):
         if kwargs["litellm_params"].get("metadata") is None:
             pass
         else:
-            deployment = kwargs["litellm_params"]["metadata"].get(
-                "deployment", None
-            )
             model_group = kwargs["litellm_params"]["metadata"].get(
                 "model_group", None
             )
             id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
-            if deployment is None or model_group is None or id is None:
+            if model_group is None or id is None:
                 return

-            # map deployment to id
-            self.mapping_deployment_to_id[deployment] = id
-
             request_count_api_key = f"{model_group}_request_count"
             # update cache
             request_count_dict = (
                 self.router_cache.get_cache(key=request_count_api_key) or {}
             )
-            request_count_dict[deployment] = (
-                request_count_dict.get(deployment, 0) + 1
-            )
+            request_count_dict[id] = request_count_dict.get(id, 0) + 1
             self.router_cache.set_cache(
                 key=request_count_api_key, value=request_count_dict
             )

@@ -62,13 +55,12 @@ class LeastBusyLoggingHandler(CustomLogger):
         if kwargs["litellm_params"].get("metadata") is None:
             pass
         else:
-            deployment = kwargs["litellm_params"]["metadata"].get(
-                "deployment", None
-            )
             model_group = kwargs["litellm_params"]["metadata"].get(
                 "model_group", None
             )
-            if deployment is None or model_group is None:
+            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+            if model_group is None or id is None:
                 return

             request_count_api_key = f"{model_group}_request_count"

@@ -76,7 +68,7 @@ class LeastBusyLoggingHandler(CustomLogger):
             request_count_dict = (
                 self.router_cache.get_cache(key=request_count_api_key) or {}
             )
-            request_count_dict[deployment] = request_count_dict.get(deployment)
+            request_count_dict[id] = request_count_dict.get(id) - 1
             self.router_cache.set_cache(
                 key=request_count_api_key, value=request_count_dict
             )

@@ -88,13 +80,11 @@ class LeastBusyLoggingHandler(CustomLogger):
         if kwargs["litellm_params"].get("metadata") is None:
             pass
         else:
-            deployment = kwargs["litellm_params"]["metadata"].get(
-                "deployment", None
-            )
             model_group = kwargs["litellm_params"]["metadata"].get(
                 "model_group", None
             )
-            if deployment is None or model_group is None:
+            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+            if model_group is None or id is None:
                 return

             request_count_api_key = f"{model_group}_request_count"

@@ -102,7 +92,7 @@ class LeastBusyLoggingHandler(CustomLogger):
             request_count_dict = (
                 self.router_cache.get_cache(key=request_count_api_key) or {}
             )
-            request_count_dict[deployment] = request_count_dict.get(deployment)
+            request_count_dict[id] = request_count_dict.get(id) - 1
             self.router_cache.set_cache(
                 key=request_count_api_key, value=request_count_dict
             )

@@ -111,11 +101,6 @@ class LeastBusyLoggingHandler(CustomLogger):

     def get_available_deployments(self, model_group: str):
         request_count_api_key = f"{model_group}_request_count"
-        request_count_dict = (
-            self.router_cache.get_cache(key=request_count_api_key) or {}
-        )
+        return_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
         # map deployment to id
-        return_dict = {}
-        for key, value in request_count_dict.items():
-            return_dict[self.mapping_deployment_to_id[key]] = value
         return return_dict
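With the deployment-to-id mapping gone, get_available_deployments simply returns whatever dict is stored under the f"{model_group}_request_count" cache key, already keyed by model id; the router then walks that dict to find the id with the fewest in-flight requests. A rough sketch of that consuming loop, with made-up traffic numbers:

deployments = {1: 10, 2: 54, 3: 100}  # {model id: requests in flight}, as returned from the cache

min_traffic = float("inf")
min_deployment = None
for k, v in deployments.items():
    if v < min_traffic:  # keep the id with the least in-flight traffic
        min_traffic = v
        min_deployment = k

assert min_deployment == 1  # id 1 is the least busy deployment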
Test file for least-busy routing:

@@ -1,79 +1,134 @@
-# #### What this tests ####
-# # This tests the router's ability to identify the least busy deployment
+#### What this tests ####
+# This tests the router's ability to identify the least busy deployment

-# #
-# # How is this achieved?
-# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
-# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
-# # - use litellm.success + failure callbacks to log when a request completed
-# # - in get_available_deployment, for a given model group name -> pick based on traffic
+#
+# How is this achieved?
+# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# - use litellm.success + failure callbacks to log when a request completed
+# - in get_available_deployment, for a given model group name -> pick based on traffic

-# import sys, os, asyncio, time
-# import traceback
-# from dotenv import load_dotenv
+import sys, os, asyncio, time
+import traceback
+from dotenv import load_dotenv

-# load_dotenv()
-# import os
+load_dotenv()
+import os

-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import pytest
-# from litellm import Router
-# import litellm
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+from litellm import Router
+import litellm
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.caching import DualCache

-# async def test_least_busy_routing():
+### UNIT TESTS FOR LEAST BUSY LOGGING ###
-#     model_list = [{
-#         "model_name": "azure-model",
-#         "litellm_params": {
-#             "model": "azure/gpt-turbo",
-#             "api_key": "os.environ/AZURE_FRANCE_API_KEY",
-#             "api_base": "https://openai-france-1234.openai.azure.com",
-#             "rpm": 1440,
-#         }
-#     }, {
-#         "model_name": "azure-model",
-#         "litellm_params": {
-#             "model": "azure/gpt-35-turbo",
-#             "api_key": "os.environ/AZURE_EUROPE_API_KEY",
-#             "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
-#             "rpm": 6
-#         }
-#     }, {
-#         "model_name": "azure-model",
-#         "litellm_params": {
-#             "model": "azure/gpt-35-turbo",
-#             "api_key": "os.environ/AZURE_CANADA_API_KEY",
-#             "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-#             "rpm": 6
-#         }
-#     }]
-#     router = Router(model_list=model_list,
-#                     routing_strategy="least-busy",
-#                     set_verbose=False,
-#                     num_retries=3) # type: ignore

-#     async def call_azure_completion():
-#         try:
-#             response = await router.acompletion(
-#                 model="azure-model",
-#                 messages=[
-#                     {
-#                         "role": "user",
-#                         "content": "hello this request will pass"
-#                     }
-#                 ]
-#             )
-#             print("\n response", response)
-#             return response
-#         except:
-#             return None

-#     n = 1000
-#     start_time = time.time()
-#     tasks = [call_azure_completion() for _ in range(n)]
-#     chat_completions = await asyncio.gather(*tasks)
-#     successful_completions = [c for c in chat_completions if c is not None]
-#     print(n, time.time() - start_time, len(successful_completions))
+def test_model_added():
+    test_cache = DualCache()
+    least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": "1234"},
+        }
+    }
+    least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
+    request_count_api_key = f"gpt-3.5-turbo_request_count"
+    assert test_cache.get_cache(key=request_count_api_key) is not None

-# asyncio.run(test_least_busy_routing())
+
+
+def test_get_available_deployments():
+    test_cache = DualCache()
+    least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
+    model_group = "gpt-3.5-turbo"
+    deployment = "azure/chatgpt-v-2"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": model_group,
+                "deployment": deployment,
+            },
+            "model_info": {"id": "1234"},
+        }
+    }
+    least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
+    request_count_api_key = f"{model_group}_request_count"
+    assert test_cache.get_cache(key=request_count_api_key) is not None
+    print(least_busy_logger.get_available_deployments(model_group=model_group))
+
+
+# test_get_available_deployments()
+
+
+def test_router_get_available_deployments():
+    """
+    Tests if 'get_available_deployments' returns the least busy deployment
+    """
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_CANADA_API_KEY",
+                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 3},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="least-busy",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    model_group = "azure-model"
+    deployment = "azure/chatgpt-v-2"
+    request_count_dict = {1: 10, 2: 54, 3: 100}
+    cache_key = f"{model_group}_request_count"
+    router.cache.set_cache(key=cache_key, value=request_count_dict)
+
+    deployment = router.get_available_deployment(model=model_group, messages=None)
+    print(f"deployment: {deployment}")
+    assert deployment["model_info"]["id"] == 1
+
+    ## run router completion - assert that the least-busy deployment was incremented
+    router.completion(
+        model=model_group,
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+
+    least_busy_dict = router.cache.get_cache(key=cache_key)
+    assert least_busy_dict[1] == 11
+
+
+# test_router_get_available_deployments()
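Note: these are ordinary pytest functions; assuming the file is collected by pytest as usual, a single test can be run with, for example, pytest -k test_router_get_available_deployments -s (the -s flag keeps the print output visible).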