forked from phoenix/litellm-mirror
Merge pull request #4706 from BerriAI/litellm_retry_after
Return `retry-after` header for rate limited requests
Commit d0fb685c56
7 changed files with 93 additions and 22 deletions
3 file diffs suppressed because one or more lines are too long
@@ -1624,11 +1624,17 @@ class ProxyException(Exception):
         type: str,
         param: Optional[str],
         code: Optional[int],
+        headers: Optional[Dict[str, str]] = None,
     ):
         self.message = message
         self.type = type
         self.param = param
         self.code = code
+        if headers is not None:
+            for k, v in headers.items():
+                if not isinstance(v, str):
+                    headers[k] = str(v)
+        self.headers = headers or {}
 
         # rules for proxyExceptions
         # Litellm router.py returns "No healthy deployment available" when there are no deployments available
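Note on the ProxyException change above: HTTP header values must be strings, so the new constructor logic coerces anything else (for example a float number of seconds for retry-after) before storing the dict, and falls back to an empty dict so downstream code can always read exc.headers. A minimal standalone sketch of that normalization, written here only for illustration and not taken from litellm itself:

from typing import Dict, Optional


def normalize_headers(headers: Optional[Dict[str, object]] = None) -> Dict[str, str]:
    # Same idea as the constructor above: header values must be str,
    # so coerce ints/floats (e.g. a computed retry-after) and never return None.
    if headers is None:
        return {}
    return {k: v if isinstance(v, str) else str(v) for k, v in headers.items()}


print(normalize_headers({"retry-after": 17.5}))  # {'retry-after': '17.5'}
print(normalize_headers())                       # {}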
@@ -1,6 +1,6 @@
 import sys
 import traceback
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Optional
 
 from fastapi import HTTPException
@@ -44,9 +44,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if current is None:
             if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                 # base case
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
             new_val = {
                 "current_requests": 1,
                 "current_tpm": 0,
@@ -73,8 +71,28 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             raise HTTPException(
                 status_code=429,
                 detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
+                headers={"retry-after": str(self.time_to_next_minute())},
             )
 
+    def time_to_next_minute(self) -> float:
+        # Get the current time
+        now = datetime.now()
+
+        # Calculate the next minute
+        next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
+
+        # Calculate the difference in seconds
+        seconds_to_next_minute = (next_minute - now).total_seconds()
+
+        return seconds_to_next_minute
+
+    def raise_rate_limit_error(self) -> HTTPException:
+        raise HTTPException(
+            status_code=429,
+            detail="Max parallel request limit reached.",
+            headers={"retry-after": str(self.time_to_next_minute())},
+        )
+
     async def async_pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
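The retry-after value comes from the new time_to_next_minute() helper: the limiter tracks rpm/tpm in per-minute windows, so the earliest point at which a blocked caller can expect the counters to reset is the top of the next minute. The same calculation as a standalone snippet (identical logic to the method above, just outside the class):

from datetime import datetime, timedelta


def time_to_next_minute() -> float:
    # Seconds remaining until the next minute boundary, i.e. when the
    # per-minute rpm/tpm counters roll over to a fresh window.
    now = datetime.now()
    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
    return (next_minute - now).total_seconds()


print(time_to_next_minute())  # e.g. 23.4813 if called at hh:mm:36.5187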
@@ -112,9 +130,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 current_global_requests = 1
             # if above -> raise error
             if current_global_requests >= global_max_parallel_requests:
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
             # if below -> increment
             else:
                 await self.internal_usage_cache.async_increment_cache(
@@ -142,9 +158,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             ):
                 pass
             elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
             elif current is None:
                 new_val = {
                     "current_requests": 1,
@@ -169,9 +183,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     request_count_api_key, new_val
                 )
             else:
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
 
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
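Taken together, the limiter hunks above mean every rate-limit rejection now surfaces as a 429 whose retry-after header tells the caller how many seconds to back off. A rough client-side sketch of honoring that header when talking to the proxy; the URL, API key, and payload are placeholders, and the 1-second fallback is an arbitrary choice, not something the commit specifies:

import time

import requests

PROXY_URL = "http://localhost:4000/chat/completions"  # placeholder endpoint


def post_with_backoff(payload: dict, api_key: str, max_attempts: int = 3):
    resp = None
    for _ in range(max_attempts):
        resp = requests.post(
            PROXY_URL,
            json=payload,
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=30,
        )
        if resp.status_code != 429:
            return resp
        # Honor the server-suggested wait; fall back to 1s if the header is missing.
        time.sleep(float(resp.headers.get("retry-after", 1)))
    return resp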
@@ -329,6 +329,7 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
     # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
+    headers = exc.headers
     return JSONResponse(
         status_code=(
             int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -341,6 +342,7 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
                 "code": exc.code,
             }
         },
+        headers=headers,
     )
 
 
@@ -3003,11 +3005,13 @@ async def chat_completion(
         router_model_names = llm_router.model_names if llm_router is not None else []
 
         if isinstance(e, HTTPException):
+            # print("e.headers={}".format(e.headers))
             raise ProxyException(
                 message=getattr(e, "detail", str(e)),
                 type=getattr(e, "type", "None"),
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+                headers=getattr(e, "headers", {}),
             )
         error_msg = f"{str(e)}"
         raise ProxyException(
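The proxy server hunks work because FastAPI/Starlette's JSONResponse accepts a headers mapping and writes it onto the HTTP response, so whatever the limiter attached to the exception reaches the wire. A minimal FastAPI sketch of the same pattern; MyRateLimitError, the route, and the 42-second value are invented for illustration and are not litellm code:

from typing import Dict, Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


class MyRateLimitError(Exception):
    def __init__(self, message: str, headers: Optional[Dict[str, str]] = None):
        self.message = message
        self.headers = headers or {}


@app.exception_handler(MyRateLimitError)
async def rate_limit_handler(request: Request, exc: MyRateLimitError):
    # Anything in exc.headers (e.g. {"retry-after": "42"}) is emitted
    # verbatim as response headers next to the JSON error body.
    return JSONResponse(
        status_code=429,
        content={"error": {"message": exc.message}},
        headers=exc.headers,
    )


@app.get("/limited")
async def limited():
    raise MyRateLimitError("Rate limit reached", headers={"retry-after": "42"})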
@@ -1,9 +1,14 @@
 # What this tests?
 ## Unit Tests for the max parallel request limiter for the proxy
 
-import sys, os, asyncio, time, random
-from datetime import datetime
+import asyncio
+import os
+import random
+import sys
+import time
+import traceback
+from datetime import datetime
 
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -12,16 +17,18 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
+from datetime import datetime
+
 import pytest
+
 import litellm
 from litellm import Router
-from litellm.proxy.utils import ProxyLogging, hash_token
-from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
 )
-from datetime import datetime
+from litellm.proxy.utils import ProxyLogging, hash_token
 
 ## On Request received
 ## On Request success
@@ -139,6 +146,50 @@ async def test_pre_call_hook_rpm_limits():
         assert e.status_code == 429
 
 
+@pytest.mark.asyncio
+async def test_pre_call_hook_rpm_limits_retry_after():
+    """
+    Test if rate limit error, returns 'retry_after'
+    """
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
+    )
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj="",
+        start_time="",
+        end_time="",
+    )
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+        assert hasattr(e, "headers")
+        assert "retry-after" in e.headers
+
+
 @pytest.mark.asyncio
 async def test_pre_call_hook_team_rpm_limits():
     """
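The assertions in the new test lean on the fact that FastAPI's HTTPException carries its headers dict on the exception object itself, which is also what the limiter's raise_rate_limit_error() produces. A small standalone illustration of that behavior, independent of litellm (the 17.3-second value is arbitrary):

from fastapi import HTTPException


def raise_rate_limit_error(seconds_to_reset: float) -> None:
    # Same shape as the limiter's error: a 429 whose headers carry the
    # suggested wait, stringified so it is a valid header value.
    raise HTTPException(
        status_code=429,
        detail="Max parallel request limit reached.",
        headers={"retry-after": str(seconds_to_reset)},
    )


try:
    raise_rate_limit_error(17.3)
except HTTPException as e:
    assert e.status_code == 429
    assert "retry-after" in e.headers
    assert float(e.headers["retry-after"]) == 17.3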
@@ -467,9 +518,10 @@ async def test_normal_router_call():
 
 @pytest.mark.asyncio
 async def test_normal_router_tpm_limit():
-    from litellm._logging import verbose_proxy_logger
     import logging
 
+    from litellm._logging import verbose_proxy_logger
+
     verbose_proxy_logger.setLevel(level=logging.DEBUG)
     model_list = [
         {