fix(least_busy.py): support consistent use of model id instead of deployment name

Krrish Dholakia 2023-12-29 17:05:14 +05:30
parent 06e4b301b4
commit 678bbfa9be
3 changed files with 144 additions and 98 deletions
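The gist of the change, as exercised by the updated tests below: the least-busy request counts are now keyed by each deployment's unique model id (model_info["id"]) rather than by its deployment name, which is ambiguous when several deployments in a model group share one name. A minimal sketch of the before/after cache shape — the dict literals here are illustrative, not taken from the handler itself:

# Before (hypothetical): counts keyed by deployment name - two
# "azure/gpt-35-turbo" deployments in the same model group would
# collide on a single key.
request_count_by_deployment = {"azure/gpt-turbo": 10, "azure/gpt-35-turbo": 54}

# After: counts keyed by model id, matching the {1: 10, 2: 54, 3: 100}
# dict the new test seeds into the router cache below.
request_count_by_id = {1: 10, 2: 54, 3: 100}

# Least-busy selection then reduces to picking the id with the fewest
# in-flight requests.
least_busy_id = min(request_count_by_id, key=request_count_by_id.get)
assert least_busy_id == 1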


@@ -1,79 +1,134 @@
-# #### What this tests ####
-# # This tests the router's ability to identify the least busy deployment
+#### What this tests ####
+# This tests the router's ability to identify the least busy deployment
-# #
-# # How is this achieved?
-# # - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
-# # - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
-# # - use litellm.success + failure callbacks to log when a request completed
-# # - in get_available_deployment, for a given model group name -> pick based on traffic
+#
+# How is this achieved?
+# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
+# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
+# - use litellm.success + failure callbacks to log when a request completed
+# - in get_available_deployment, for a given model group name -> pick based on traffic
+
-# import sys, os, asyncio, time
-# import traceback
-# from dotenv import load_dotenv
+import sys, os, asyncio, time
+import traceback
+from dotenv import load_dotenv
+
-# load_dotenv()
-# import os
+load_dotenv()
+import os
+
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import pytest
-# from litellm import Router
-# import litellm
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+from litellm import Router
+import litellm
+from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.caching import DualCache
+
-# async def test_least_busy_routing():
-#     model_list = [{
-#         "model_name": "azure-model",
-#         "litellm_params": {
-#             "model": "azure/gpt-turbo",
-#             "api_key": "os.environ/AZURE_FRANCE_API_KEY",
-#             "api_base": "https://openai-france-1234.openai.azure.com",
-#             "rpm": 1440,
-#         }
-#     }, {
-#         "model_name": "azure-model",
-#         "litellm_params": {
-#             "model": "azure/gpt-35-turbo",
-#             "api_key": "os.environ/AZURE_EUROPE_API_KEY",
-#             "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
-#             "rpm": 6
-#         }
-#     }, {
-#         "model_name": "azure-model",
-#         "litellm_params": {
-#             "model": "azure/gpt-35-turbo",
-#             "api_key": "os.environ/AZURE_CANADA_API_KEY",
-#             "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-#             "rpm": 6
-#         }
-#     }]
-#     router = Router(model_list=model_list,
-#                     routing_strategy="least-busy",
-#                     set_verbose=False,
-#                     num_retries=3)  # type: ignore
+### UNIT TESTS FOR LEAST BUSY LOGGING ###
-#     async def call_azure_completion():
-#         try:
-#             response = await router.acompletion(
-#                 model="azure-model",
-#                 messages=[
-#                     {
-#                         "role": "user",
-#                         "content": "hello this request will pass"
-#                     }
-#                 ]
-#             )
-#             print("\n response", response)
-#             return response
-#         except:
-#             return None
-#     n = 1000
-#     start_time = time.time()
-#     tasks = [call_azure_completion() for _ in range(n)]
-#     chat_completions = await asyncio.gather(*tasks)
-#     successful_completions = [c for c in chat_completions if c is not None]
-#     print(n, time.time() - start_time, len(successful_completions))
+def test_model_added():
+    test_cache = DualCache()
+    least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": "gpt-3.5-turbo",
+                "deployment": "azure/chatgpt-v-2",
+            },
+            "model_info": {"id": "1234"},
+        }
+    }
+    least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
+    request_count_api_key = f"gpt-3.5-turbo_request_count"
+    assert test_cache.get_cache(key=request_count_api_key) is not None
+
-# asyncio.run(test_least_busy_routing())
+def test_get_available_deployments():
+    test_cache = DualCache()
+    least_busy_logger = LeastBusyLoggingHandler(router_cache=test_cache)
+    model_group = "gpt-3.5-turbo"
+    deployment = "azure/chatgpt-v-2"
+    kwargs = {
+        "litellm_params": {
+            "metadata": {
+                "model_group": model_group,
+                "deployment": deployment,
+            },
+            "model_info": {"id": "1234"},
+        }
+    }
+    least_busy_logger.log_pre_api_call(model="test", messages=[], kwargs=kwargs)
+    request_count_api_key = f"{model_group}_request_count"
+    assert test_cache.get_cache(key=request_count_api_key) is not None
+    print(least_busy_logger.get_available_deployments(model_group=model_group))
+
+
+# test_get_available_deployments()
+
+def test_router_get_available_deployments():
+    """
+    Tests if 'get_available_deployments' returns the least busy deployment
+    """
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_CANADA_API_KEY",
+                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 3},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="least-busy",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+    model_group = "azure-model"
+    deployment = "azure/chatgpt-v-2"
+    request_count_dict = {1: 10, 2: 54, 3: 100}
+    cache_key = f"{model_group}_request_count"
+    router.cache.set_cache(key=cache_key, value=request_count_dict)
+
+    deployment = router.get_available_deployment(model=model_group, messages=None)
+    print(f"deployment: {deployment}")
+    assert deployment["model_info"]["id"] == 1
+
+    ## run router completion - assert that the least-busy deployment was incremented
+    router.completion(
+        model=model_group,
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+    least_busy_dict = router.cache.get_cache(key=cache_key)
+    assert least_busy_dict[1] == 11
+
+
+# test_router_get_available_deployments()
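
Taken together, the tests pin down a small contract for LeastBusyLoggingHandler: log_pre_api_call writes a {model_id: in_flight_count} dict under the cache key "{model_group}_request_count", and the router reads that dict back to pick a deployment. A minimal usage sketch, assuming only the behavior the tests above assert:

from litellm.caching import DualCache
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler

cache = DualCache()
handler = LeastBusyLoggingHandler(router_cache=cache)

# Same kwargs shape the tests pass: the handler reads the model group,
# deployment name, and unique model id out of litellm_params.
kwargs = {
    "litellm_params": {
        "metadata": {"model_group": "azure-model", "deployment": "azure/gpt-turbo"},
        "model_info": {"id": 1},
    }
}
handler.log_pre_api_call(model="test", messages=[], kwargs=kwargs)

# The cache entry is keyed by model group; its value maps model id -> count,
# e.g. {1: 1} after one pre-call log.
print(cache.get_cache(key="azure-model_request_count"))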