forked from phoenix/litellm-mirror

fix(router.py): add random shuffle and tpm-based shuffle for async shuffle logic

parent c015e5e2c6
commit a520e1bd6f

4 changed files with 147 additions and 1 deletion
@@ -29,7 +29,7 @@ model_list:
       # api_base: https://exampleopenaiendpoint-production.up.railway.app/

 router_settings:
-  routing_strategy: usage-based-routing-v2
+  # routing_strategy: usage-based-routing-v2
   # redis_url: "os.environ/REDIS_URL"
   redis_host: os.environ/REDIS_HOST
   redis_port: os.environ/REDIS_PORT
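For reference, the router_settings block corresponds to keyword arguments on litellm's Router. A minimal sketch of the equivalent in Python — the single-entry model_list and the env-var defaults are illustrative, not part of this commit:

    import os
    from litellm import Router

    # hypothetical one-deployment model_list, standing in for the
    # model_list section of the same config file
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0613",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]

    # the "os.environ/..." values in the YAML are resolved by the proxy;
    # here we read the environment directly
    router = Router(
        model_list=model_list,
        routing_strategy="usage-based-routing-v2",  # the setting toggled in this hunk
        redis_host=os.getenv("REDIS_HOST"),
        redis_port=int(os.getenv("REDIS_PORT", "6379")),
    )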
@@ -2872,7 +2872,27 @@ class Router:
                     f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
                 )
                 return deployment or deployment[0]
+
+            ############## Check if we can do a RPM/TPM based weighted pick #################
+            tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
+            if tpm is not None:
+                # use weighted-random pick if tpms are provided
+                tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
+                verbose_router_logger.debug(f"\ntpms {tpms}")
+                total_tpm = sum(tpms)
+                weights = [tpm / total_tpm for tpm in tpms]
+                verbose_router_logger.debug(f"\n weights {weights}")
+                # Perform weighted random pick
+                selected_index = random.choices(range(len(tpms)), weights=weights)[0]
+                verbose_router_logger.debug(f"\n selected index, {selected_index}")
+                deployment = healthy_deployments[selected_index]
+                verbose_router_logger.info(
+                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
+                )
+                return deployment or deployment[0]
+
+            ############## No RPM/TPM passed, we do a random pick #################
+            item = random.choice(healthy_deployments)
+            return item or item[0]
         if deployment is None:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
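The TPM-weighted branch is just random.choices with proportional weights. A standalone sketch, using hypothetical deployment dicts shaped like litellm model_list entries (no litellm imports), that mirrors the added logic:

    import random

    # hypothetical deployments; tpm values match the test below
    healthy_deployments = [
        {"litellm_params": {"model": "gpt-3.5-turbo-0613", "tpm": 6}},
        {"litellm_params": {"model": "azure/chatgpt-v-2", "tpm": 1440}},
    ]

    tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
    total_tpm = sum(tpms)  # 1446
    weights = [t / total_tpm for t in tpms]  # [~0.004, ~0.996]

    # weighted random pick: index 1 wins ~99.6% of the time
    selected_index = random.choices(range(len(tpms)), weights=weights)[0]
    print(healthy_deployments[selected_index]["litellm_params"]["model"])

One caveat: the division assumes at least one positive tpm; a deployment list whose tpm values sum to 0 would raise ZeroDivisionError in this branch.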
@@ -512,3 +512,76 @@ async def test_wildcard_openai_routing():

     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+"""
+Test async router get deployment (Simple-shuffle)
+"""
+
+rpm_list = [[None, None], [6, 1440]]
+tpm_list = [[None, None], [6, 1440]]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "rpm_list, tpm_list",
+    [(rpm, tpm) for rpm in rpm_list for tpm in tpm_list],
+)
+async def test_weighted_selection_router_async(rpm_list, tpm_list):
+    # this tests if load balancing works based on the provided rpms in the router
+    # it's a fast test, only tests get_available_deployment
+    # users can pass rpms as a litellm_param
+    try:
+        litellm.set_verbose = False
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "rpm": rpm_list[0],
+                    "tpm": tpm_list[0],
+                },
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_version": os.getenv("AZURE_API_VERSION"),
+                    "rpm": rpm_list[1],
+                    "tpm": tpm_list[1],
+                },
+            },
+        ]
+        router = Router(
+            model_list=model_list,
+        )
+        selection_counts = defaultdict(int)
+
+        # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time
+        for _ in range(1000):
+            selected_model = await router.async_get_available_deployment(
+                "gpt-3.5-turbo"
+            )
+            selected_model_id = selected_model["litellm_params"]["model"]
+            selected_model_name = selected_model_id
+            selection_counts[selected_model_name] += 1
+        print(selection_counts)
+
+        total_requests = sum(selection_counts.values())
+
+        if rpm_list[0] is not None or tpm_list[0] is not None:
+            # Assert that 'azure/chatgpt-v-2' has about 90% of the total requests
+            assert (
+                selection_counts["azure/chatgpt-v-2"] / total_requests > 0.89
+            ), f"Assertion failed: 'azure/chatgpt-v-2' does not have about 90% of the total requests in the weighted load balancer. Selection counts {selection_counts}"
+        else:
+            # Assert both are used
+            assert selection_counts["azure/chatgpt-v-2"] > 0
+            assert selection_counts["gpt-3.5-turbo-0613"] > 0
+        router.reset()
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
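Why 0.89 as the threshold: with the [6, 1440] parametrization the expected share for azure/chatgpt-v-2 is 1440/1446, well above it. A quick check:

    # expected share of picks for azure/chatgpt-v-2 under weights [6, 1440]
    share = 1440 / (1440 + 6)
    print(f"{share:.4f}")  # 0.9959 — comfortably above the 0.89 assertion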
litellm/tests/test_simple_shuffle.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+# What is this?
+## unit tests for 'simple-shuffle'
+
+import sys, os, asyncio, time, random
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+from litellm import Router
+
+"""
+Test random shuffle
+- async
+- sync
+"""
+
+
+async def test_simple_shuffle():
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="usage-based-routing-v2",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
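As committed, the new file stops after constructing the Router, so nothing exercises the shuffle yet. A hedged sketch of how a continuation might drive it — the drive_shuffle helper and the 100-draw tally are hypothetical, not part of this commit:

    from collections import Counter

    async def drive_shuffle(router: Router):
        # tally which deployment the router picks across repeated draws
        counts = Counter()
        for _ in range(100):
            deployment = await router.async_get_available_deployment("azure-model")
            counts[deployment["model_info"]["id"]] += 1
        return counts

    # e.g. at the end of test_simple_shuffle:
    #     counts = await drive_shuffle(router)
    #     assert counts[1] > 0 and counts[2] > 0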