fix(parallel_request_limiter.py): return remaining tpm/rpm in openai-compatible way

Fixes https://github.com/BerriAI/litellm/issues/5957
Krrish Dholakia 2024-09-28 15:56:12 -07:00
parent c0cdc6e496
commit 5222fc8e1b
4 changed files with 166 additions and 2 deletions
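Illustrative only, not part of the commit: the hook below attaches the new values to the response's _hidden_params["additional_headers"]; assuming the proxy then forwards those as plain HTTP response headers, a client can read them the same way it reads OpenAI's own rate-limit headers. A minimal sketch — the base URL, virtual key, and model alias are placeholders:

from openai import OpenAI

# Hypothetical proxy endpoint and virtual key, for illustration only
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

raw = client.chat.completions.with_raw_response.create(
    model="azure-model",
    messages=[{"role": "user", "content": "hi"}],
)

# Remaining budget for the calling key, in OpenAI's header format
print(raw.headers.get("x-ratelimit-limit-requests"))
print(raw.headers.get("x-ratelimit-remaining-requests"))
print(raw.headers.get("x-ratelimit-limit-tokens"))
print(raw.headers.get("x-ratelimit-remaining-tokens"))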

View file

@@ -1913,3 +1913,9 @@ class TeamInfoResponseObject(TypedDict):
    team_info: LiteLLM_TeamTable
    keys: List
    team_memberships: List[LiteLLM_TeamMembership]


class CurrentItemRateLimit(TypedDict):
    current_requests: int
    current_tpm: int
    current_rpm: int
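Illustrative only, not part of the diff: a cache entry matching this TypedDict is just a plain dict. A hypothetical snapshot of a key's usage within the current minute window, values made up:

from litellm.proxy._types import CurrentItemRateLimit  # added by this commit

usage: CurrentItemRateLimit = {
    "current_requests": 2,  # requests currently in flight
    "current_tpm": 1450,    # tokens consumed this minute
    "current_rpm": 3,       # requests made this minute
}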

View file

@@ -11,7 +11,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
-from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy._types import CurrentItemRateLimit, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
    get_key_model_rpm_limit,
    get_key_model_tpm_limit,
@@ -754,3 +754,63 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                "Parallel Request Limiter: Error getting user object", str(e)
            )
            return None

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        """
        Retrieve the key's remaining rate limits.
        """
        api_key = user_api_key_dict.api_key
        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"

        request_count_api_key = f"{api_key}::{precise_minute}::request_count"
        current: Optional[CurrentItemRateLimit] = (
            await self.internal_usage_cache.async_get_cache(
                key=request_count_api_key,
                litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            )
        )

        key_remaining_rpm_limit: Optional[int] = None
        key_rpm_limit: Optional[int] = None
        key_remaining_tpm_limit: Optional[int] = None
        key_tpm_limit: Optional[int] = None
        if current is not None:
            if user_api_key_dict.rpm_limit is not None:
                key_remaining_rpm_limit = (
                    user_api_key_dict.rpm_limit - current["current_rpm"]
                )
                key_rpm_limit = user_api_key_dict.rpm_limit
            if user_api_key_dict.tpm_limit is not None:
                key_remaining_tpm_limit = (
                    user_api_key_dict.tpm_limit - current["current_tpm"]
                )
                key_tpm_limit = user_api_key_dict.tpm_limit

        _hidden_params = getattr(response, "_hidden_params", {}) or {}
        _additional_headers = _hidden_params.get("additional_headers", {}) or {}

        if key_remaining_rpm_limit is not None:
            _additional_headers["x-ratelimit-remaining-requests"] = (
                key_remaining_rpm_limit
            )
        if key_rpm_limit is not None:
            _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
        if key_remaining_tpm_limit is not None:
            _additional_headers["x-ratelimit-remaining-tokens"] = (
                key_remaining_tpm_limit
            )
        if key_tpm_limit is not None:
            _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit

        setattr(
            response,
            "_hidden_params",
            {**_hidden_params, "additional_headers": _additional_headers},
        )

        return await super().async_post_call_success_hook(
            data, user_api_key_dict, response
        )
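Illustrative only, not part of the diff: the header values above are plain subtraction of the cached per-minute counters from the key's configured limits. A hypothetical worked example, with the limits and counters made up:

# key configured with rpm_limit=100 and tpm_limit=1000
current = {"current_requests": 1, "current_tpm": 40, "current_rpm": 3}

additional_headers = {
    "x-ratelimit-limit-requests": 100,
    "x-ratelimit-remaining-requests": 100 - current["current_rpm"],  # 97
    "x-ratelimit-limit-tokens": 1000,
    "x-ratelimit-remaining-tokens": 1000 - current["current_tpm"],   # 960
}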

View file

@@ -35,7 +35,13 @@ from typing_extensions import overload
import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
-from litellm import EmbeddingResponse, ImageResponse, ModelResponse, get_litellm_params
+from litellm import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    Router,
+    get_litellm_params,
+)
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching import DualCache, RedisCache

View file

@@ -1203,3 +1203,95 @@ async def test_pre_call_hook_tpm_limits_per_model():
            "request limit reached Hit TPM limit for model: azure-model on api_key"
            in str(e)
        )
@pytest.mark.asyncio
@pytest.mark.flaky(retries=6, delay=1)
async def test_post_call_success_hook_rpm_limits_per_model():
    """
    Test if openai-compatible x-ratelimit-* headers are added to the response
    """
    import logging

    from litellm import ModelResponse
    from litellm._logging import (
        verbose_logger,
        verbose_proxy_logger,
        verbose_router_logger,
    )

    verbose_logger.setLevel(logging.DEBUG)
    verbose_proxy_logger.setLevel(logging.DEBUG)
    verbose_router_logger.setLevel(logging.DEBUG)
    _api_key = "sk-12345"
    _api_key = hash_token(_api_key)
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key,
        max_parallel_requests=100,
        tpm_limit=900000,
        rpm_limit=100000,
        metadata={
            "model_tpm_limit": {"azure-model": 1},
            "model_rpm_limit": {"azure-model": 100},
        },
    )
    local_cache = DualCache()
    pl = ProxyLogging(user_api_key_cache=local_cache)
    pl._init_litellm_callbacks()
    print(f"litellm callbacks: {litellm.callbacks}")
    parallel_request_handler = pl.max_parallel_request_limiter

    model = "azure-model"
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={"model": model},
        call_type="",
    )

    kwargs = {
        "model": model,
        "litellm_params": {
            "metadata": {
                "user_api_key": _api_key,
                "model_group": model,
                "user_api_key_metadata": {
                    "model_tpm_limit": {"azure-model": 1},
                    "model_rpm_limit": {"azure-model": 100},
                },
            }
        },
    }

    await parallel_request_handler.async_log_success_event(
        kwargs=kwargs,
        response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=11)),
        start_time="",
        end_time="",
    )

    current_date = datetime.now().strftime("%Y-%m-%d")
    current_hour = datetime.now().strftime("%H")
    current_minute = datetime.now().strftime("%M")
    precise_minute = f"{current_date}-{current_hour}-{current_minute}"

    request_count_api_key = f"{_api_key}::{model}::{precise_minute}::request_count"

    print(f"request_count_api_key: {request_count_api_key}")
    current_cache = parallel_request_handler.internal_usage_cache.get_cache(
        key=request_count_api_key
    )
    print("current cache: ", current_cache)

    response = ModelResponse()
    await parallel_request_handler.async_post_call_success_hook(
        data={}, user_api_key_dict=user_api_key_dict, response=response
    )

    hidden_params = getattr(response, "_hidden_params", {}) or {}
    print(hidden_params)

    assert "additional_headers" in hidden_params
    assert "x-ratelimit-limit-requests" in hidden_params["additional_headers"]
    assert "x-ratelimit-remaining-requests" in hidden_params["additional_headers"]
    assert "x-ratelimit-limit-tokens" in hidden_params["additional_headers"]
    assert "x-ratelimit-remaining-tokens" in hidden_params["additional_headers"]