diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 9b2bfbe24..4dc1075a4 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1913,3 +1913,9 @@ class TeamInfoResponseObject(TypedDict):
     team_info: LiteLLM_TeamTable
     keys: List
     team_memberships: List[LiteLLM_TeamMembership]
+
+
+class CurrentItemRateLimit(TypedDict):
+    current_requests: int
+    current_tpm: int
+    current_rpm: int
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 864ad5260..7af2a409c 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -11,7 +11,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
-from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy._types import CurrentItemRateLimit, UserAPIKeyAuth
 from litellm.proxy.auth.auth_utils import (
     get_key_model_rpm_limit,
     get_key_model_tpm_limit,
@@ -754,3 +754,63 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 "Parallel Request Limiter: Error getting user object", str(e)
             )
             return None
+
+    async def async_post_call_success_hook(
+        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
+    ):
+        """
+        Retrieve the key's remaining rate limits.
+        """
+        api_key = user_api_key_dict.api_key
+        current_date = datetime.now().strftime("%Y-%m-%d")
+        current_hour = datetime.now().strftime("%H")
+        current_minute = datetime.now().strftime("%M")
+        precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+        request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+        current: Optional[CurrentItemRateLimit] = (
+            await self.internal_usage_cache.async_get_cache(
+                key=request_count_api_key,
+                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
+            )
+        )
+
+        key_remaining_rpm_limit: Optional[int] = None
+        key_rpm_limit: Optional[int] = None
+        key_remaining_tpm_limit: Optional[int] = None
+        key_tpm_limit: Optional[int] = None
+        if current is not None:
+            if user_api_key_dict.rpm_limit is not None:
+                key_remaining_rpm_limit = (
+                    user_api_key_dict.rpm_limit - current["current_rpm"]
+                )
+                key_rpm_limit = user_api_key_dict.rpm_limit
+            if user_api_key_dict.tpm_limit is not None:
+                key_remaining_tpm_limit = (
+                    user_api_key_dict.tpm_limit - current["current_tpm"]
+                )
+                key_tpm_limit = user_api_key_dict.tpm_limit
+
+        _hidden_params = getattr(response, "_hidden_params", {}) or {}
+        _additional_headers = _hidden_params.get("additional_headers", {}) or {}
+        if key_remaining_rpm_limit is not None:
+            _additional_headers["x-ratelimit-remaining-requests"] = (
+                key_remaining_rpm_limit
+            )
+        if key_rpm_limit is not None:
+            _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
+        if key_remaining_tpm_limit is not None:
+            _additional_headers["x-ratelimit-remaining-tokens"] = (
+                key_remaining_tpm_limit
+            )
+        if key_tpm_limit is not None:
+            _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit
+
+        setattr(
+            response,
+            "_hidden_params",
+            {**_hidden_params, "additional_headers": _additional_headers},
+        )
+
+        return await super().async_post_call_success_hook(
+            data, user_api_key_dict, response
+        )
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8c5f91c15..e6c9bcad2 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -35,7 +35,13 @@ from typing_extensions import overload
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm import EmbeddingResponse, ImageResponse, ModelResponse, get_litellm_params
+from litellm import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    Router,
+    get_litellm_params,
+)
 from litellm._logging import verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching import DualCache, RedisCache
diff --git a/tests/local_testing/test_parallel_request_limiter.py b/tests/local_testing/test_parallel_request_limiter.py
index dee7fefbf..c694cdc07 100644
--- a/tests/local_testing/test_parallel_request_limiter.py
+++ b/tests/local_testing/test_parallel_request_limiter.py
@@ -1203,3 +1203,95 @@ async def test_pre_call_hook_tpm_limits_per_model():
         "request limit reached Hit TPM limit for model: azure-model on api_key"
         in str(e)
     )
+
+
+@pytest.mark.asyncio
+@pytest.mark.flaky(retries=6, delay=1)
+async def test_post_call_success_hook_rpm_limits_per_model():
+    """
+    Test if openai-compatible x-ratelimit-* headers are added to the response
+    """
+    import logging
+
+    from litellm import ModelResponse
+    from litellm._logging import (
+        verbose_logger,
+        verbose_proxy_logger,
+        verbose_router_logger,
+    )
+
+    verbose_logger.setLevel(logging.DEBUG)
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+    verbose_router_logger.setLevel(logging.DEBUG)
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=100,
+        tpm_limit=900000,
+        rpm_limit=100000,
+        metadata={
+            "model_tpm_limit": {"azure-model": 1},
+            "model_rpm_limit": {"azure-model": 100},
+        },
+    )
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+    model = "azure-model"
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=local_cache,
+        data={"model": model},
+        call_type="",
+    )
+
+    kwargs = {
+        "model": model,
+        "litellm_params": {
+            "metadata": {
+                "user_api_key": _api_key,
+                "model_group": model,
+                "user_api_key_metadata": {
+                    "model_tpm_limit": {"azure-model": 1},
+                    "model_rpm_limit": {"azure-model": 100},
+                },
+            }
+        },
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=11)),
+        start_time="",
+        end_time="",
+    )
+
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    current_hour = datetime.now().strftime("%H")
+    current_minute = datetime.now().strftime("%M")
+    precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"
+
+    print(f"request_count_api_key: {request_count_api_key}")
+    current_cache = parallel_request_handler.internal_usage_cache.get_cache(
+        key=request_count_api_key
+    )
+    print("current cache: ", current_cache)
+
+    response = ModelResponse()
+    await parallel_request_handler.async_post_call_success_hook(
+        data={}, user_api_key_dict=user_api_key_dict, response=response
+    )
+
+    hidden_params = getattr(response, "_hidden_params", {}) or {}
+    print(hidden_params)
+    assert "additional_headers" in hidden_params
+    assert "x-ratelimit-limit-requests" in hidden_params["additional_headers"]
+    assert "x-ratelimit-remaining-requests" in hidden_params["additional_headers"]
+    assert "x-ratelimit-limit-tokens" in hidden_params["additional_headers"]
+    assert "x-ratelimit-remaining-tokens" in hidden_params["additional_headers"]