diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index e9149499d..000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard404This page could not be found.
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index 4e0dcb398..000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 6fec62da5..000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index f188801e0..7464714db 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1624,11 +1624,17 @@ class ProxyException(Exception):
type: str,
param: Optional[str],
code: Optional[int],
+ headers: Optional[Dict[str, str]] = None,
):
self.message = message
self.type = type
self.param = param
self.code = code
+ if headers is not None:
+ for k, v in headers.items():
+ if not isinstance(v, str):
+ headers[k] = str(v)
+ self.headers = headers or {}
# rules for proxyExceptions
# Litellm router.py returns "No healthy deployment available" when there are no deployments available
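
Reviewer note on the hunk above: `ProxyException` now accepts an optional `headers` mapping, and non-string values are coerced to strings so the FastAPI response can carry them directly. A minimal sketch of the intended behavior (not part of the diff; the example values are made up):

```python
# Sketch only: exercises the new `headers` kwarg on ProxyException.
from litellm.proxy._types import ProxyException

exc = ProxyException(
    message="Max parallel request limit reached.",
    type="None",
    param="None",
    code=429,
    headers={"retry-after": 17.5},  # non-string value on purpose
)
# Values are coerced to str in __init__, and a missing mapping becomes {}.
assert exc.headers == {"retry-after": "17.5"}
```
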
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index e9c8649d0..8a14b4ebe 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,6 +1,6 @@
import sys
import traceback
-from datetime import datetime
+from datetime import datetime, timedelta
from typing import Optional
from fastapi import HTTPException
@@ -44,9 +44,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
if current is None:
if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
# base case
- raise HTTPException(
- status_code=429, detail="Max parallel request limit reached."
- )
+ return self.raise_rate_limit_error()
new_val = {
"current_requests": 1,
"current_tpm": 0,
@@ -73,8 +71,28 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
raise HTTPException(
status_code=429,
detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
+ headers={"retry-after": str(self.time_to_next_minute())},
)
+ def time_to_next_minute(self) -> float:
+ # Get the current time
+ now = datetime.now()
+
+ # Calculate the next minute
+ next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
+
+ # Calculate the difference in seconds
+ seconds_to_next_minute = (next_minute - now).total_seconds()
+
+ return seconds_to_next_minute
+
+ def raise_rate_limit_error(self) -> HTTPException:
+ raise HTTPException(
+ status_code=429,
+ detail="Max parallel request limit reached.",
+ headers={"retry-after": str(self.time_to_next_minute())},
+ )
+
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
@@ -112,9 +130,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
current_global_requests = 1
# if above -> raise error
if current_global_requests >= global_max_parallel_requests:
- raise HTTPException(
- status_code=429, detail="Max parallel request limit reached."
- )
+ return self.raise_rate_limit_error()
# if below -> increment
else:
await self.internal_usage_cache.async_increment_cache(
@@ -142,9 +158,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
):
pass
elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
- raise HTTPException(
- status_code=429, detail="Max parallel request limit reached."
- )
+ return self.raise_rate_limit_error()
elif current is None:
new_val = {
"current_requests": 1,
@@ -169,9 +183,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
request_count_api_key, new_val
)
else:
- raise HTTPException(
- status_code=429, detail="Max parallel request limit reached."
- )
+ return self.raise_rate_limit_error()
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
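
For context on the hunk above: the value sent in the `retry-after` header is simply the number of seconds until the top of the next minute, which matches the per-minute TPM/RPM accounting windows. A standalone sketch of the same calculation (the function name below is mine, not the handler's method):

```python
# Sketch of the retry-after computation; mirrors time_to_next_minute() above.
from datetime import datetime, timedelta

def seconds_until_next_minute() -> float:
    now = datetime.now()
    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
    return (next_minute - now).total_seconds()

wait = seconds_until_next_minute()
assert 0 < wait <= 60  # always lands inside the next rate-limit window
```
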
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f2d59b15f..f05468d52 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -329,6 +329,7 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
@app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: ProxyException):
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
+ headers = exc.headers
return JSONResponse(
status_code=(
int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -341,6 +342,7 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
"code": exc.code,
}
},
+ headers=headers,
)
@@ -3003,11 +3005,13 @@ async def chat_completion(
router_model_names = llm_router.model_names if llm_router is not None else []
if isinstance(e, HTTPException):
+ # print("e.headers={}".format(e.headers))
raise ProxyException(
message=getattr(e, "detail", str(e)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ headers=getattr(e, "headers", {}),
)
error_msg = f"{str(e)}"
raise ProxyException(
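
With the exception handler now forwarding `exc.headers`, a rate-limited caller can read its backoff straight from the 429 response. A hypothetical client-side snippet (the URL, key, and model are placeholders, not taken from this diff):

```python
# Hypothetical usage: honor the proxy's retry-after header on a 429.
import time
import httpx

resp = httpx.post(
    "http://localhost:4000/chat/completions",  # assumed local proxy address
    headers={"Authorization": "Bearer sk-12345"},  # assumed virtual key
    json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
)
if resp.status_code == 429:
    wait = float(resp.headers.get("retry-after", "1"))
    time.sleep(wait)  # back off until the next window before retrying
```
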
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index 06817fa87..df69686e1 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -1,9 +1,14 @@
# What this tests?
## Unit Tests for the max parallel request limiter for the proxy
-import sys, os, asyncio, time, random
-from datetime import datetime
+import asyncio
+import os
+import random
+import sys
+import time
import traceback
+from datetime import datetime
+
from dotenv import load_dotenv
load_dotenv()
@@ -12,16 +17,18 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
+from datetime import datetime
+
import pytest
+
import litellm
from litellm import Router
-from litellm.proxy.utils import ProxyLogging, hash_token
-from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
)
-from datetime import datetime
+from litellm.proxy.utils import ProxyLogging, hash_token
## On Request received
## On Request success
@@ -139,6 +146,50 @@ async def test_pre_call_hook_rpm_limits():
assert e.status_code == 429
+@pytest.mark.asyncio
+async def test_pre_call_hook_rpm_limits_retry_after():
+ """
+    Test that a rate limit error returns a 'retry-after' header
+ """
+ _api_key = "sk-12345"
+ user_api_key_dict = UserAPIKeyAuth(
+ api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
+ )
+ local_cache = DualCache()
+ parallel_request_handler = MaxParallelRequestsHandler(
+ internal_usage_cache=local_cache
+ )
+
+ await parallel_request_handler.async_pre_call_hook(
+ user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+ )
+
+ kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
+
+ await parallel_request_handler.async_log_success_event(
+ kwargs=kwargs,
+ response_obj="",
+ start_time="",
+ end_time="",
+ )
+
+ ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+ try:
+ await parallel_request_handler.async_pre_call_hook(
+ user_api_key_dict=user_api_key_dict,
+ cache=local_cache,
+ data={},
+ call_type="",
+ )
+
+        pytest.fail("Expected call to fail")
+ except Exception as e:
+ assert e.status_code == 429
+ assert hasattr(e, "headers")
+ assert "retry-after" in e.headers
+
+
@pytest.mark.asyncio
async def test_pre_call_hook_team_rpm_limits():
"""
@@ -467,9 +518,10 @@ async def test_normal_router_call():
@pytest.mark.asyncio
async def test_normal_router_tpm_limit():
- from litellm._logging import verbose_proxy_logger
import logging
+ from litellm._logging import verbose_proxy_logger
+
verbose_proxy_logger.setLevel(level=logging.DEBUG)
model_list = [
{