Merge pull request #4706 from BerriAI/litellm_retry_after

Return `retry-after` header for rate limited requests
Krish Dholakia 2024-07-13 21:37:41 -07:00 committed by GitHub
commit d0fb685c56
7 changed files with 93 additions and 22 deletions
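With this change, a 429 from the proxy's parallel-request limiter carries a `retry-after` header holding the number of seconds until the per-minute counters reset. Below is a minimal client-side sketch of how a caller might honor that header; the proxy URL, API key, and model name are placeholders, not part of this PR.

```python
import time

import requests  # any HTTP client works; requests is assumed here for brevity

PROXY_URL = "http://localhost:4000/v1/chat/completions"  # placeholder proxy address
API_KEY = "sk-1234"                                      # placeholder virtual key


def chat_with_retry(payload: dict, max_attempts: int = 3) -> dict:
    """POST to the proxy, sleeping for `retry-after` seconds on a 429."""
    for _ in range(max_attempts):
        resp = requests.post(
            PROXY_URL,
            headers={"Authorization": f"Bearer {API_KEY}"},
            json=payload,
            timeout=30,
        )
        if resp.status_code != 429:
            resp.raise_for_status()
            return resp.json()
        # rate limited: the proxy now tells us how long to wait
        wait_s = float(resp.headers.get("retry-after", "1"))
        time.sleep(wait_s)
    raise RuntimeError("still rate limited after retries")


# usage (placeholder model name):
# chat_with_retry({"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]})
```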

3 file diffs suppressed because one or more lines are too long

View file

@@ -1624,11 +1624,17 @@ class ProxyException(Exception):
        type: str,
        param: Optional[str],
        code: Optional[int],
        headers: Optional[Dict[str, str]] = None,
    ):
        self.message = message
        self.type = type
        self.param = param
        self.code = code
        if headers is not None:
            for k, v in headers.items():
                if not isinstance(v, str):
                    headers[k] = str(v)
        self.headers = headers or {}

        # rules for proxyExceptions
        # Litellm router.py returns "No healthy deployment available" when there are no deployments available
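A minimal usage sketch of the new `headers` parameter on `ProxyException`; the import path, `type` value, and the retry-after number are assumptions for illustration. Because `__init__` coerces non-string header values, an integer retry-after can be passed directly.

```python
from litellm.proxy._types import ProxyException  # module path assumed from the imports used elsewhere in the repo

exc = ProxyException(
    message="Max parallel request limit reached.",
    type="throttling_error",      # illustrative value
    param=None,
    code=429,
    headers={"retry-after": 17},  # int on purpose: __init__ coerces it to "17"
)
assert exc.headers == {"retry-after": "17"}
```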

View file

@@ -1,6 +1,6 @@
import sys
import traceback
from datetime import datetime
from datetime import datetime, timedelta
from typing import Optional

from fastapi import HTTPException

@@ -44,9 +44,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        if current is None:
            if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                # base case
                raise HTTPException(
                    status_code=429, detail="Max parallel request limit reached."
                )
                return self.raise_rate_limit_error()
            new_val = {
                "current_requests": 1,
                "current_tpm": 0,

@@ -73,6 +71,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            raise HTTPException(
                status_code=429,
                detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
                headers={"retry-after": str(self.time_to_next_minute())},
            )

    def time_to_next_minute(self) -> float:
        # Get the current time
        now = datetime.now()

        # Calculate the next minute
        next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)

        # Calculate the difference in seconds
        seconds_to_next_minute = (next_minute - now).total_seconds()

        return seconds_to_next_minute

    def raise_rate_limit_error(self) -> HTTPException:
        raise HTTPException(
            status_code=429,
            detail="Max parallel request limit reached.",
            headers={"retry-after": str(self.time_to_next_minute())},
        )

    async def async_pre_call_hook(

@@ -112,9 +130,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                current_global_requests = 1
            # if above -> raise error
            if current_global_requests >= global_max_parallel_requests:
                raise HTTPException(
                    status_code=429, detail="Max parallel request limit reached."
                )
                return self.raise_rate_limit_error()
            # if below -> increment
            else:
                await self.internal_usage_cache.async_increment_cache(

@@ -142,9 +158,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        ):
            pass
        elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
            raise HTTPException(
                status_code=429, detail="Max parallel request limit reached."
            )
            return self.raise_rate_limit_error()
        elif current is None:
            new_val = {
                "current_requests": 1,

@@ -169,9 +183,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                request_count_api_key, new_val
            )
        else:
            raise HTTPException(
                status_code=429, detail="Max parallel request limit reached."
            )
            return self.raise_rate_limit_error()
        # check if REQUEST ALLOWED for user_id
        user_id = user_api_key_dict.user_id
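The `retry-after` value is simply the time remaining until the top of the next minute, since the limiter's RPM/TPM counters are tracked per minute. A standalone sketch of the same arithmetic with a fixed timestamp, separate from the proxy code:

```python
from datetime import datetime, timedelta


def seconds_to_next_minute(now: datetime) -> float:
    # same arithmetic as time_to_next_minute(), but with an injectable "now" for clarity
    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
    return (next_minute - now).total_seconds()


# e.g. at 12:07:45.500 the client is told to retry in 14.5 seconds,
# i.e. at 12:08:00, when the per-minute window rolls over
print(seconds_to_next_minute(datetime(2024, 7, 13, 12, 7, 45, 500000)))  # 14.5
```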

View file

@@ -329,6 +329,7 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
@app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: ProxyException):
    # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
    headers = exc.headers
    return JSONResponse(
        status_code=(
            int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR

@@ -341,6 +342,7 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
                "code": exc.code,
            }
        },
        headers=headers,
    )

@@ -3003,11 +3005,13 @@ async def chat_completion(
        router_model_names = llm_router.model_names if llm_router is not None else []
        if isinstance(e, HTTPException):
            # print("e.headers={}".format(e.headers))
            raise ProxyException(
                message=getattr(e, "detail", str(e)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
                headers=getattr(e, "headers", {}),
            )
        error_msg = f"{str(e)}"
        raise ProxyException(
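The exception-handler change is what puts the header on the wire: Starlette's `JSONResponse` accepts a `headers` mapping and copies it onto the HTTP response. A stripped-down sketch of that pattern, with class and handler names that are illustrative rather than LiteLLM's:

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


class RateLimitError(Exception):  # stand-in for ProxyException, illustrative only
    def __init__(self, retry_after: float):
        self.retry_after = retry_after


@app.exception_handler(RateLimitError)
async def rate_limit_handler(request: Request, exc: RateLimitError):
    # the headers dict is copied onto the outgoing HTTP response,
    # which is how the proxy's retry-after value reaches the client
    return JSONResponse(
        status_code=429,
        content={"error": {"message": "rate limited", "code": 429}},
        headers={"retry-after": str(exc.retry_after)},
    )
```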

View file

@@ -1,9 +1,14 @@
# What this tests?
## Unit Tests for the max parallel request limiter for the proxy
import sys, os, asyncio, time, random
from datetime import datetime
import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime
from dotenv import load_dotenv

load_dotenv()

@@ -12,16 +17,18 @@ import os
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from datetime import datetime
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
)
from datetime import datetime
from litellm.proxy.utils import ProxyLogging, hash_token

## On Request received
## On Request success

@@ -139,6 +146,50 @@ async def test_pre_call_hook_rpm_limits():
        assert e.status_code == 429


@pytest.mark.asyncio
async def test_pre_call_hook_rpm_limits_retry_after():
    """
    Test if rate limit error, returns 'retry_after'
    """
    _api_key = "sk-12345"
    user_api_key_dict = UserAPIKeyAuth(
        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
    )
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler(
        internal_usage_cache=local_cache
    )

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}

    await parallel_request_handler.async_log_success_event(
        kwargs=kwargs,
        response_obj="",
        start_time="",
        end_time="",
    )

    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}

    try:
        await parallel_request_handler.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=local_cache,
            data={},
            call_type="",
        )

        pytest.fail(f"Expected call to fail")
    except Exception as e:
        assert e.status_code == 429
        assert hasattr(e, "headers")
        assert "retry-after" in e.headers


@pytest.mark.asyncio
async def test_pre_call_hook_team_rpm_limits():
    """

@@ -467,9 +518,10 @@ async def test_normal_router_call():
@pytest.mark.asyncio
async def test_normal_router_tpm_limit():
    from litellm._logging import verbose_proxy_logger
    import logging

    from litellm._logging import verbose_proxy_logger

    verbose_proxy_logger.setLevel(level=logging.DEBUG)
    model_list = [
        {