forked from phoenix/litellm-mirror
fix(least_busy.py): support consistent use of model id instead of deployment name
parent 06e4b301b4
commit 678bbfa9be

3 changed files with 144 additions and 98 deletions
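For context, here is a minimal sketch (illustrative values only, not taken from the diff) of the bookkeeping this commit standardizes: the least-busy state lives in the router cache under the key f"{model_group}_request_count", and the in-flight counts are now keyed by each deployment's model id instead of its deployment name, so no deployment-to-id mapping is needed when reading them back.

# Illustrative only: the shape of the least-busy state before and after this commit.
model_group = "azure-model"
cache_key = f"{model_group}_request_count"

# Before: counts keyed by deployment name; a separate deployment -> id mapping
# was needed to resolve an entry back to the model list.
before = {"azure/gpt-turbo": 10, "azure/gpt-35-turbo": 54}

# After: counts keyed directly by model_info["id"], which is what the router
# compares against healthy_deployments when picking the least busy deployment.
after = {"1": 10, "2": 54, "3": 100}

print(cache_key, after)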
@@ -287,7 +287,7 @@ class Router:
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        timeout = kwargs.get("request_timeout", self.timeout)
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        # response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)

        response = await self.async_function_with_fallbacks(**kwargs)

        return response
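The kwargs.setdefault(...) line above is what stamps the model group onto the request metadata; the least-busy logging handler later reads it to build its cache key. A tiny standalone illustration (not part of the commit):

# Standalone illustration: the router records the model group on the request so the
# least-busy handler can bucket in-flight counts under f"{model_group}_request_count".
kwargs = {}
model = "azure-model"
kwargs.setdefault("metadata", {}).update({"model_group": model})
print(kwargs)  # {'metadata': {'model_group': 'azure-model'}}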
@@ -1664,6 +1664,7 @@ class Router:
            deployments = self.leastbusy_logger.get_available_deployments(
                model_group=model
            )
            self.print_verbose(f"deployments in least-busy router: {deployments}")
            # pick least busy deployment
            min_traffic = float("inf")
            min_deployment = None
@@ -1671,14 +1672,19 @@ class Router:
                if v < min_traffic:
                    min_traffic = v
                    min_deployment = k
            self.print_verbose(f"min_deployment: {min_deployment};")
            ############## No Available Deployments passed, we do a random pick #################
            if min_deployment is None:
                min_deployment = random.choice(healthy_deployments)
            ############## Available Deployments passed, we find the relevant item #################
            else:
                ## check if min deployment is a string, if so, cast it to int
                if isinstance(min_deployment, str):
                    min_deployment = int(min_deployment)
                for m in healthy_deployments:
                    if m["model_info"]["id"] == min_deployment:
                        return m
                self.print_verbose(f"no healthy deployment with that id found!")
                min_deployment = random.choice(healthy_deployments)
            return min_deployment
        elif self.routing_strategy == "simple-shuffle":
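To make the selection step above concrete, here is a small standalone sketch (ids and counts borrowed from the test file below, not part of the commit). get_available_deployments now returns a map of model id to in-flight requests; the router takes the id with the least traffic, casts string keys to int, and resolves it against healthy_deployments via model_info["id"].

# Standalone sketch of the selection logic above; ids and counts mirror the test below.
deployments = {"1": 10, "2": 54, "3": 100}  # model id -> in-flight requests
healthy_deployments = [
    {"model_name": "azure-model", "model_info": {"id": 1}},
    {"model_name": "azure-model", "model_info": {"id": 2}},
    {"model_name": "azure-model", "model_info": {"id": 3}},
]

min_id = min(deployments, key=deployments.get)  # "1" -> least traffic
chosen = next(
    (m for m in healthy_deployments if m["model_info"]["id"] == int(min_id)),
    None,  # the router itself falls back to a random healthy deployment here
)
print(chosen)  # {'model_name': 'azure-model', 'model_info': {'id': 1}}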
@@ -30,27 +30,20 @@ class LeastBusyLoggingHandler(CustomLogger):
        if kwargs["litellm_params"].get("metadata") is None:
            pass
        else:
            deployment = kwargs["litellm_params"]["metadata"].get(
                "deployment", None
            )
            model_group = kwargs["litellm_params"]["metadata"].get(
                "model_group", None
            )
            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
            if deployment is None or model_group is None or id is None:
            if model_group is None or id is None:
                return

            # map deployment to id
            self.mapping_deployment_to_id[deployment] = id

            request_count_api_key = f"{model_group}_request_count"
            # update cache
            request_count_dict = (
                self.router_cache.get_cache(key=request_count_api_key) or {}
            )
            request_count_dict[deployment] = (
                request_count_dict.get(deployment, 0) + 1
            )
            request_count_dict[id] = request_count_dict.get(id, 0) + 1

            self.router_cache.set_cache(
                key=request_count_api_key, value=request_count_dict
            )
@@ -62,13 +55,12 @@ class LeastBusyLoggingHandler(CustomLogger):
        if kwargs["litellm_params"].get("metadata") is None:
            pass
        else:
            deployment = kwargs["litellm_params"]["metadata"].get(
                "deployment", None
            )
            model_group = kwargs["litellm_params"]["metadata"].get(
                "model_group", None
            )
            if deployment is None or model_group is None:

            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
            if model_group is None or id is None:
                return

            request_count_api_key = f"{model_group}_request_count"
@@ -76,7 +68,7 @@ class LeastBusyLoggingHandler(CustomLogger):
            request_count_dict = (
                self.router_cache.get_cache(key=request_count_api_key) or {}
            )
            request_count_dict[deployment] = request_count_dict.get(deployment)
            request_count_dict[id] = request_count_dict.get(id) - 1
            self.router_cache.set_cache(
                key=request_count_api_key, value=request_count_dict
            )
@@ -88,13 +80,11 @@ class LeastBusyLoggingHandler(CustomLogger):
        if kwargs["litellm_params"].get("metadata") is None:
            pass
        else:
            deployment = kwargs["litellm_params"]["metadata"].get(
                "deployment", None
            )
            model_group = kwargs["litellm_params"]["metadata"].get(
                "model_group", None
            )
            if deployment is None or model_group is None:
            id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
            if model_group is None or id is None:
                return

            request_count_api_key = f"{model_group}_request_count"
@@ -102,7 +92,7 @@ class LeastBusyLoggingHandler(CustomLogger):
            request_count_dict = (
                self.router_cache.get_cache(key=request_count_api_key) or {}
            )
            request_count_dict[deployment] = request_count_dict.get(deployment)
            request_count_dict[id] = request_count_dict.get(id) - 1
            self.router_cache.set_cache(
                key=request_count_api_key, value=request_count_dict
            )
@@ -111,11 +101,6 @@ class LeastBusyLoggingHandler(CustomLogger):

    def get_available_deployments(self, model_group: str):
        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = (
            self.router_cache.get_cache(key=request_count_api_key) or {}
        )
        return_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
        # map deployment to id
        return_dict = {}
        for key, value in request_count_dict.items():
            return_dict[self.mapping_deployment_to_id[key]] = value
        return return_dict
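A short usage sketch of the simplified get_available_deployments, mirroring the unit tests in the diff below (the expected output assumes the id-keyed increment is the half of the pre-call hunk that survives the change):

from litellm.caching import DualCache
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler

# Log one in-flight request, then read the id-keyed traffic map back out of the cache.
cache = DualCache()
handler = LeastBusyLoggingHandler(router_cache=cache)
kwargs = {
    "litellm_params": {
        "metadata": {"model_group": "gpt-3.5-turbo", "deployment": "azure/chatgpt-v-2"},
        "model_info": {"id": "1234"},
    }
}
handler.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
print(handler.get_available_deployments(model_group="gpt-3.5-turbo"))
# expected: {'1234': 1} - one in-flight request, tracked under the model id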
@@ -1,79 +1,134 @@
# #### What this tests ####
# # This tests the router's ability to identify the least busy deployment
#### What this tests ####
# This tests the router's ability to identify the least busy deployment

# #
# # How is this achieved?
# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# # - use litellm.success + failure callbacks to log when a request completed
# # - in get_available_deployment, for a given model group name -> pick based on traffic
#
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic

# import sys, os, asyncio, time
# import traceback
# from dotenv import load_dotenv
import sys, os, asyncio, time
import traceback
from dotenv import load_dotenv

# load_dotenv()
# import os
load_dotenv()
import os

# sys.path.insert(
#     0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# from litellm import Router
# import litellm
sys.path.insert(
    0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import Router
import litellm
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.caching import DualCache

# async def test_least_busy_routing():
#     model_list = [{
#         "model_name": "azure-model",
#         "litellm_params": {
#             "model": "azure/gpt-turbo",
#             "api_key": "os.environ/AZURE_FRANCE_API_KEY",
#             "api_base": "https://openai-france-1234.openai.azure.com",
#             "rpm": 1440,
#         }
#     }, {
#         "model_name": "azure-model",
#         "litellm_params": {
#             "model": "azure/gpt-35-turbo",
#             "api_key": "os.environ/AZURE_EUROPE_API_KEY",
#             "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
#             "rpm": 6
#         }
#     }, {
#         "model_name": "azure-model",
#         "litellm_params": {
#             "model": "azure/gpt-35-turbo",
#             "api_key": "os.environ/AZURE_CANADA_API_KEY",
#             "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
#             "rpm": 6
#         }
#     }]
#     router = Router(model_list=model_list,
#                     routing_strategy="least-busy",
#                     set_verbose=False,
#                     num_retries=3) # type: ignore
### UNIT TESTS FOR LEAST BUSY LOGGING ###

#     async def call_azure_completion():
#         try:
#             response = await router.acompletion(
#                 model="azure-model",
#                 messages=[
#                     {
#                         "role": "user",
#                         "content": "hello this request will pass"
#                     }
#                 ]
#             )
#             print("\n response", response)
#             return response
#         except:
#             return None

#     n = 1000
#     start_time = time.time()
#     tasks = [call_azure_completion() for _ in range(n)]
#     chat_completions = await asyncio.gather(*tasks)
#     successful_completions = [c for c in chat_completions if c is not None]
#     print(n, time.time() - start_time, len(successful_completions))
def test_model_added():
    test_cache = DualCache()
    least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": "gpt-3.5-turbo",
                "deployment": "azure/chatgpt-v-2",
            },
            "model_info": {"id": "1234"},
        }
    }
    least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
    request_count_api_key = f"gpt-3.5-turbo_request_count"
    assert test_cache.get_cache(key=request_count_api_key) is not None

# asyncio.run(test_least_busy_routing())

def test_get_available_deployments():
    test_cache = DualCache()
    least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
    model_group = "gpt-3.5-turbo"
    deployment = "azure/chatgpt-v-2"
    kwargs = {
        "litellm_params": {
            "metadata": {
                "model_group": model_group,
                "deployment": deployment,
            },
            "model_info": {"id": "1234"},
        }
    }
    least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
    request_count_api_key = f"{model_group}_request_count"
    assert test_cache.get_cache(key=request_count_api_key) is not None
    print(least_busy_logger.get_available_deployments(model_group=model_group))


# test_get_available_deployments()


def test_router_get_available_deployments():
    """
    Tests if 'get_available_deployments' returns the least busy deployment
    """
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_CANADA_API_KEY",
                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 3},
        },
    ]
    router = Router(
        model_list=model_list,
        routing_strategy="least-busy",
        set_verbose=False,
        num_retries=3,
    ) # type: ignore

    model_group = "azure-model"
    deployment = "azure/chatgpt-v-2"
    request_count_dict = {1: 10, 2: 54, 3: 100}
    cache_key = f"{model_group}_request_count"
    router.cache.set_cache(key=cache_key, value=request_count_dict)

    deployment = router.get_available_deployment(model=model_group, messages=None)
    print(f"deployment: {deployment}")
    assert deployment["model_info"]["id"] == 1

    ## run router completion - assert that the least-busy deployment was incremented

    router.completion(
        model=model_group,
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )

    least_busy_dict = router.cache.get_cache(key=cache_key)
    assert least_busy_dict[1] == 11


# test_router_get_available_deployments()