(test) test_usage_based_routing

This commit is contained in:
ishaan-jaff 2024-01-19 14:59:56 -08:00
parent 06096c7165
commit b3570c94a0

View file

@@ -375,3 +375,76 @@ def test_model_group_aliases():
# test_model_group_aliases()
def test_usage_based_routing():
    """
    Check how usage-based routing spreads requests across one model group.

    Two Azure deployments are registered under the same group name
    ("azure/gpt-4"): "chatgpt-low-tpm" with tpm=100 and "chatgpt-high-tpm"
    with tpm=1000. The test sends 25 mocked completions through the router,
    tallies the model reported by each response, and asserts that:

    * "chatgpt-low-tpm" served more than 2 requests, and
    * "chatgpt-high-tpm" served more than 80% of all requests.

    Intended scenario (per the original author's note): at some point the TPM
    limit is exceeded for the low-TPM deployment only, while the high-TPM
    deployment is still under its limit.

    Reads AZURE_API_KEY, AZURE_API_VERSION, AZURE_API_BASE, REDIS_HOST and
    REDIS_PORT from the environment. Any exception (including a missing
    variable or a failed assertion) is reported through ``pytest.fail``.
    """
    try:

        def get_azure_params(deployment_name: str) -> dict:
            """Build the ``litellm_params`` dict for one Azure deployment."""
            return {
                "model": f"azure/{deployment_name}",
                "api_key": os.environ["AZURE_API_KEY"],
                "api_version": os.environ["AZURE_API_VERSION"],
                "api_base": os.environ["AZURE_API_BASE"],
            }

        # Both entries share one model_name so they form a single group;
        # only the deployment and its tpm value differ.
        model_list = [
            {
                "model_name": "azure/gpt-4",
                "litellm_params": get_azure_params("chatgpt-low-tpm"),
                "tpm": 100,
            },
            {
                "model_name": "azure/gpt-4",
                "litellm_params": get_azure_params("chatgpt-high-tpm"),
                "tpm": 1000,
            },
        ]

        router = Router(
            model_list=model_list,
            set_verbose=True,
            debug_level="DEBUG",
            routing_strategy="usage-based-routing",
            redis_host=os.environ["REDIS_HOST"],
            # Environment values are always strings; hand over a numeric port.
            redis_port=int(os.environ["REDIS_PORT"]),
        )

        messages = [
            {"content": "Tell me a joke.", "role": "user"},
        ]

        # Tally of responses keyed by the model name each response reports.
        selection_counts = defaultdict(int)
        for _ in range(25):
            response = router.completion(
                model="azure/gpt-4",
                messages=messages,
                timeout=5,
                mock_response="good morning",
            )
            selection_counts[response["model"]] += 1

        print(selection_counts)

        total_requests = sum(selection_counts.values())

        # Assert that 'chatgpt-low-tpm' has more than 2 requests
        assert (
            selection_counts["chatgpt-low-tpm"] > 2
        ), f"Assertion failed: 'chatgpt-low-tpm' does not have more than 2 requests with usage-based routing. Selection counts {selection_counts}"

        # Assert that 'chatgpt-high-tpm' has more than 80% of the total requests
        assert (
            selection_counts["chatgpt-high-tpm"] / total_requests > 0.8
        ), f"Assertion failed: 'chatgpt-high-tpm' does not have more than 80% of the total requests with usage-based routing. Selection counts {selection_counts}"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")