mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-28 04:04:31 +00:00
(test) test_usage_based_routing
This commit is contained in:
parent
06096c7165
commit
b3570c94a0
1 changed files with 73 additions and 0 deletions
|
@ -375,3 +375,76 @@ def test_model_group_aliases():
|
||||||
|
|
||||||
|
|
||||||
# test_model_group_aliases()
|
# test_model_group_aliases()
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_based_routing():
    """
    Check how the "usage-based-routing" strategy spreads requests over two
    deployments that share one model group.

    Two Azure deployments are registered under the group name "azure/gpt-4":
    "chatgpt-low-tpm" (tpm=100) and "chatgpt-high-tpm" (tpm=1000). Note that
    `tpm` is declared next to `litellm_params` in each entry, not inside it.

    The router is asked for 25 completions (each call passes
    mock_response="good morning" -- presumably so no real Azure request is
    made; confirm against Router.completion). The `model` field of every
    response is tallied and the test then requires that:

    * "chatgpt-low-tpm" was reported more than 2 times, and
    * "chatgpt-high-tpm" was reported for more than 80% of all responses.

    Reads AZURE_API_KEY, AZURE_API_VERSION, AZURE_API_BASE, REDIS_HOST and
    REDIS_PORT from the environment. Any exception -- including a failed
    assertion or a missing environment variable -- is reported through
    pytest.fail.
    """
    try:
        # (deployment name, tpm) for each member of the "azure/gpt-4" group.
        deployments = [("chatgpt-low-tpm", 100), ("chatgpt-high-tpm", 1000)]

        def azure_litellm_params(deployment_name: str):
            # Credentials are shared; only the deployment name differs.
            return {
                "model": f"azure/{deployment_name}",
                "api_key": os.environ["AZURE_API_KEY"],
                "api_version": os.environ["AZURE_API_VERSION"],
                "api_base": os.environ["AZURE_API_BASE"],
            }

        model_list = [
            {
                "model_name": "azure/gpt-4",
                "litellm_params": azure_litellm_params(deployment_name),
                "tpm": tpm_limit,
            }
            for deployment_name, tpm_limit in deployments
        ]

        # Resolve the Redis settings first (host, then port) so a missing
        # variable surfaces before the router is constructed.
        redis_host = os.environ["REDIS_HOST"]
        redis_port = os.environ["REDIS_PORT"]
        router = Router(
            model_list=model_list,
            set_verbose=True,
            debug_level="DEBUG",
            routing_strategy="usage-based-routing",
            redis_host=redis_host,
            redis_port=redis_port,
        )

        prompt = [{"content": "Tell me a joke.", "role": "user"}]

        # Tally of how often each model name shows up on a response.
        selection_counts = defaultdict(int)
        for _attempt in range(25):
            reply = router.completion(
                model="azure/gpt-4",
                messages=prompt,
                timeout=5,
                mock_response="good morning",
            )
            served_by = reply["model"]
            selection_counts[served_by] += 1

        print(selection_counts)

        total_requests = sum(selection_counts.values())

        # The low-TPM deployment must still be picked more than twice.
        assert (
            selection_counts["chatgpt-low-tpm"] > 2
        ), f"Assertion failed: 'chatgpt-low-tpm' does not have more than 2 request in the weighted load balancer. Selection counts {selection_counts}"

        # The high-TPM deployment must take strictly more than 80% of the traffic.
        assert (
            selection_counts["chatgpt-high-tpm"] / total_requests > 0.8
        ), f"Assertion failed: 'chatgpt-high-tpm' does not have about 80% of the total requests in the weighted load balancer. Selection counts {selection_counts}"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue