litellm/tests/test_ratelimit.py
2024-04-02 13:49:13 -07:00

99 lines
2.8 KiB
Python

# %%
import asyncio
import os
import pytest
import random
from typing import Any
from pydantic import BaseModel
from litellm import utils, Router
COMPLETION_TOKENS = 5
base_model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
"max_tokens": COMPLETION_TOKENS,
},
}
]
class RouterConfig(BaseModel):
rpm: int
tpm: int
@pytest.fixture(scope="function")
def router_factory():
def create_router(rpm, tpm):
model_list = base_model_list.copy()
model_list[0]["rpm"] = rpm
model_list[0]["tpm"] = tpm
return Router(
model_list=model_list,
routing_strategy="usage-based-routing",
debug_level="DEBUG",
)
return create_router
def generate_list_of_messages(num_messages):
return [
[{"role": "user", "content": f"{i}. Hey, how's it going? {random.random()}"}]
for i in range(num_messages)
]
def calculate_limits(list_of_messages):
rpm = len(list_of_messages)
tpm = sum((utils.token_counter(messages=m) + COMPLETION_TOKENS for m in list_of_messages))
return rpm, tpm
async def async_call(router: Router, list_of_messages) -> Any:
tasks = [router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages]
return await asyncio.gather(*tasks)
def sync_call(router: Router, list_of_messages) -> Any:
return [router.completion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages]
class ExpectNoException(Exception):
pass
@pytest.mark.parametrize(
"num_messages, num_rate_limits",
[
# (2, 30), # No exception expected
# (2, 3), # No exception expected
(2, 2), # No exception expected
(3, 2), # Expect ValueError
# (6, 5), # Expect ValueError
],
)
@pytest.mark.parametrize("sync_mode", [True, False]) # Use parametrization for sync/async
def test_rate_limit(router_factory, num_messages, num_rate_limits, sync_mode):
expected_exception = ExpectNoException if num_messages <= num_rate_limits else ValueError
list_of_messages = generate_list_of_messages(max(num_messages, num_rate_limits))
rpm, tpm = calculate_limits(list_of_messages[:num_rate_limits])
list_of_messages = list_of_messages[:num_messages]
router = router_factory(rpm, tpm)
with pytest.raises(expected_exception) as excinfo:
if sync_mode:
results = sync_call(router, list_of_messages)
else:
results = asyncio.run(async_call(router, list_of_messages))
raise ExpectNoException
if expected_exception is not ExpectNoException:
assert "No deployments available for selected model" in str(excinfo.value)
else:
assert len([i for i in results if i is not None]) == num_messages