From 398295116f38346333d1dee2bd7cc09e6ec72e96 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Sun, 18 Aug 2024 09:57:31 -0700 Subject: [PATCH] only write model tpm/rpm tracking when user sets it --- .../proxy/hooks/parallel_request_limiter.py | 14 ++++++++++- .../tests/test_parallel_request_limiter.py | 23 ++++++++++++++++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index e813d2597..38b57c19e 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -400,6 +400,11 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ) user_api_key_end_user_id = kwargs.get("user") + user_api_key_metadata = ( + kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {}) + or {} + ) + # ------------ # Setup values # ------------ @@ -456,7 +461,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # Update usage - model group + API Key # ------------ model_group = get_model_group_from_litellm_kwargs(kwargs) - if user_api_key is not None and model_group is not None: + if ( + user_api_key is not None + and model_group is not None + and ( + "model_rpm_limit" in user_api_key_metadata + or "model_tpm_limit" in user_api_key_metadata + ) + ): request_count_api_key = ( f"{user_api_key}::{model_group}::{precise_minute}::request_count" ) diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index 38b85bf7f..463f9cf50 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -990,7 +990,13 @@ async def test_bad_router_tpm_limit_per_model(): model=model, messages=[{"role": "user2", "content": "Write me a paragraph on the moon"}], stream=True, - metadata={"user_api_key": _api_key}, + metadata={ + "user_api_key": _api_key, + "user_api_key_metadata": { + "model_rpm_limit": {model: 5}, + 
"model_tpm_limit": {model: 5}, + }, + }, ) except: pass @@ -1047,7 +1053,11 @@ async def test_pre_call_hook_rpm_limits_per_model(): kwargs = { "model": model, "litellm_params": { - "metadata": {"user_api_key": _api_key, "model_group": model} + "metadata": { + "user_api_key": _api_key, + "model_group": model, + "user_api_key_metadata": {"model_rpm_limit": {"azure-model": 1}}, + }, }, } @@ -1124,7 +1134,14 @@ async def test_pre_call_hook_tpm_limits_per_model(): kwargs = { "model": model, "litellm_params": { - "metadata": {"user_api_key": _api_key, "model_group": model} + "metadata": { + "user_api_key": _api_key, + "model_group": model, + "user_api_key_metadata": { + "model_tpm_limit": {"azure-model": 1}, + "model_rpm_limit": {"azure-model": 100}, + }, + } }, }