diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 4ee3bd5a14..4af86c24c4 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -215,6 +215,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             current = await self.internal_usage_cache.async_get_cache(
                 key=request_count_api_key
             )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+
             tpm_limit_for_model = None
             rpm_limit_for_model = None
 
@@ -237,8 +238,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     request_count_api_key, new_val
                 )
             elif tpm_limit_for_model is not None or rpm_limit_for_model is not None:
+                # Increase count for this token
                 new_val = {
-                    "current_requests": 1,
+                    "current_requests": current["current_requests"] + 1,
                     "current_tpm": current["current_tpm"],
                     "current_rpm": current["current_rpm"],
                 }
@@ -247,14 +249,18 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     and current["current_tpm"] >= tpm_limit_for_model
                 ):
                     return self.raise_rate_limit_error(
-                        additional_details=f"Hit limit for model: {_model} on api_key: {api_key}. tpm_limit: {tpm_limit_for_model}, current_tpm {current['current_tpm']} "
+                        additional_details=f"Hit TPM limit for model: {_model} on api_key: {api_key}. tpm_limit: {tpm_limit_for_model}, current_tpm {current['current_tpm']} "
                     )
                 elif (
                     rpm_limit_for_model is not None
                     and current["current_rpm"] >= rpm_limit_for_model
                 ):
                     return self.raise_rate_limit_error(
-                        additional_details=f"Hit limit for model: {_model} on api_key: {api_key}. rpm_limit: {rpm_limit_for_model}, current_rpm {current['current_rpm']} "
+                        additional_details=f"Hit RPM limit for model: {_model} on api_key: {api_key}. rpm_limit: {rpm_limit_for_model}, current_rpm {current['current_rpm']} "
+                    )
+                else:
+                    await self.internal_usage_cache.async_set_cache(
+                        request_count_api_key, new_val
                     )
 
         # check if REQUEST ALLOWED for user_id
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index e43cfd7778..e6ffa272f7 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -1000,3 +1000,171 @@ async def test_bad_router_tpm_limit_per_model():
         )["current_tpm"]
         == 0
     )
+
+
+@pytest.mark.asyncio
+async def test_pre_call_hook_rpm_limits_per_model():
+    """
+    Test if error raised on hitting rpm limits for a given model
+    """
+    import logging
+
+    from litellm._logging import (
+        verbose_logger,
+        verbose_proxy_logger,
+        verbose_router_logger,
+    )
+
+    verbose_logger.setLevel(logging.DEBUG)
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+    verbose_router_logger.setLevel(logging.DEBUG)
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=100,
+        tpm_limit=900000,
+        rpm_limit=100000,
+        rpm_limit_per_model={"azure-model": 1},
+    )
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    model = "azure-model"
+
+    kwargs = {
+        "model": model,
+        "litellm_params": {"metadata": {"user_api_key": _api_key}},
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj="",
+        start_time="",
+        end_time="",
+    )
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={"model": model},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+        print("got error=", e)
+        assert (
+            "limit reached Hit RPM limit for model: azure-model on api_key: c11e7177eb60c80cf983ddf8ca98f2dc1272d4c612204ce9bedd2460b18939cc"
+            in str(e)
+        )
+
+
+@pytest.mark.asyncio
+async def test_pre_call_hook_tpm_limits_per_model():
+    """
+    Test if error raised on hitting tpm limits for a given model
+    """
+    import logging
+
+    from litellm._logging import (
+        verbose_logger,
+        verbose_proxy_logger,
+        verbose_router_logger,
+    )
+
+    verbose_logger.setLevel(logging.DEBUG)
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+    verbose_router_logger.setLevel(logging.DEBUG)
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=100,
+        tpm_limit=900000,
+        rpm_limit=100000,
+        rpm_limit_per_model={"azure-model": 100},
+        tpm_limit_per_model={"azure-model": 10},
+    )
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+    model = "azure-model"
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=local_cache,
+        data={"model": model},
+        call_type="",
+    )
+
+    kwargs = {
+        "model": model,
+        "litellm_params": {"metadata": {"user_api_key": _api_key}},
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=11)),
+        start_time="",
+        end_time="",
+    )
+
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
+
+    print(
+        "internal usage cache: ",
+        parallel_request_handler.internal_usage_cache.in_memory_cache.cache_dict,
+    )
+
+    assert (
+        parallel_request_handler.internal_usage_cache.get_cache(
+            key=request_count_api_key
+        )["current_tpm"]
+        == 11
+    )
+
+    assert (
+        parallel_request_handler.internal_usage_cache.get_cache(
+            key=request_count_api_key
+        )["current_rpm"]
+        == 1
+    )
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 11, "current_rpm": "1"}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={"model": model},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+        print("got error=", e)
+        assert (
+            "request limit reached Hit TPM limit for model: azure-model on api_key"
+            in str(e)
+        )