track rpm/tpm usage per key+model

Ishaan Jaff 2024-08-16 18:28:58 -07:00
parent a6a4b944ad
commit 1ee33478c9
2 changed files with 177 additions and 0 deletions

@@ -908,3 +908,95 @@ async def test_bad_router_tpm_limit():
        )["current_tpm"]
        == 0
    )


@pytest.mark.asyncio
async def test_bad_router_tpm_limit_per_model():
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        set_verbose=False,
        num_retries=3,
    )  # type: ignore

    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    model = "azure-model"
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key,
        max_parallel_requests=10,
        tpm_limit=10,
        tpm_limit_per_model={model: 5},
        rpm_limit_per_model={model: 5},
    )
    local_cache = DualCache()
    pl = ProxyLogging(user_api_key_cache=local_cache)
    pl._init_litellm_callbacks()
    print(f"litellm callbacks: {litellm.callbacks}")
    parallel_request_handler = pl.max_parallel_request_limiter

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={"model": model},
        call_type="",
    )

    current_date = datetime.now().strftime("%Y-%m-%d")
    current_hour = datetime.now().strftime("%H")
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"

    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"

    print(
        "internal usage cache: ",
        parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
    )

    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
        )["current_requests"]
        == 1
    )
    # bad call: the invalid "user2" role should make the request fail,
    # so no tokens get recorded against the per key+model tpm counter
    try:
        response = await router.acompletion(
            model=model,
            messages=[
                {"role": "user2", "content": "Write me a paragraph on the moon"}
            ],
            stream=True,
            metadata={"user_api_key": _api_key},
        )
    except Exception:
        pass

    await asyncio.sleep(1)  # success is done in a separate thread

    assert (
        parallel_request_handler.internal_usage_cache.get_cache(
            key=request_count_api_key
        )["current_tpm"]
        == 0
    )
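
For context, the counter key the test asserts against encodes the hashed key, the model, and the current minute: hashed_api_key::model::YYYY-MM-DD-HH-MM::request_count. That means every key+model pair gets its own per-minute rpm/tpm window. Below is a minimal, hypothetical sketch of that scheme; the SimpleRateLimiter class and its plain-dict cache are illustrative stand-ins, not the proxy's actual max_parallel_request_limiter or internal_usage_cache.

from datetime import datetime


class SimpleRateLimiter:
    """Hypothetical sketch of per key+model usage tracking.

    Mirrors the cache key format asserted in the test above:
    hashed_api_key::model::YYYY-MM-DD-HH-MM::request_count
    """

    _DEFAULTS = {"current_requests": 0, "current_tpm": 0, "current_rpm": 0}

    def __init__(self):
        # stand-in for the proxy's internal usage cache
        self.cache: dict = {}

    def _key(self, api_key: str, model: str) -> str:
        precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
        return f"{api_key}::{model}::{precise_minute}::request_count"

    def pre_call(self, api_key: str, model: str, rpm_limit: int) -> None:
        # reserve a slot before the call; raise if this key+model pair
        # has already hit its per-minute rpm limit
        entry = self.cache.setdefault(self._key(api_key, model), dict(self._DEFAULTS))
        if entry["current_rpm"] >= rpm_limit:
            raise Exception("per key+model rpm limit reached")
        entry["current_requests"] += 1

    def on_success(self, api_key: str, model: str, tokens_used: int) -> None:
        # only successful calls add usage, which is why the failed call in
        # the test leaves current_tpm at 0
        entry = self.cache.setdefault(self._key(api_key, model), dict(self._DEFAULTS))
        entry["current_requests"] -= 1
        entry["current_tpm"] += tokens_used
        entry["current_rpm"] += 1


limiter = SimpleRateLimiter()
limiter.pre_call(api_key="hashed-sk-12345", model="azure-model", rpm_limit=5)
limiter.on_success(api_key="hashed-sk-12345", model="azure-model", tokens_used=42)

Because the minute is baked into the key, counters reset implicitly as time advances: a new minute simply produces a new key, and stale entries can be dropped with a short TTL.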