diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index e9149499d..000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index 4e0dcb398..000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 6fec62da5..000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index f188801e0..7464714db 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1624,11 +1624,17 @@ class ProxyException(Exception):
         type: str,
         param: Optional[str],
         code: Optional[int],
+        headers: Optional[Dict[str, str]] = None,
     ):
         self.message = message
         self.type = type
         self.param = param
         self.code = code
+        if headers is not None:
+            for k, v in headers.items():
+                if not isinstance(v, str):
+                    headers[k] = str(v)
+        self.headers = headers or {}
 
     # rules for proxyExceptions
     # Litellm router.py returns "No healthy deployment available" when there are no deployments available
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index e9c8649d0..8a14b4ebe 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,6 +1,6 @@
 import sys
 import traceback
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Optional
 
 from fastapi import HTTPException
@@ -44,9 +44,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         if current is None:
             if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                 # base case
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
             new_val = {
                 "current_requests": 1,
                 "current_tpm": 0,
@@ -73,8 +71,28 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             raise HTTPException(
                 status_code=429,
                 detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
+                headers={"retry-after": str(self.time_to_next_minute())},
             )
 
+    def time_to_next_minute(self) -> float:
+        # Get the current time
+        now = datetime.now()
+
+        # Calculate the next minute
+        next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
+
+        # Calculate the difference in seconds
+        seconds_to_next_minute = (next_minute - now).total_seconds()
+
+        return seconds_to_next_minute
+
+    def raise_rate_limit_error(self) -> HTTPException:
+        raise HTTPException(
+            status_code=429,
+            detail="Max parallel request limit reached.",
+            headers={"retry-after": str(self.time_to_next_minute())},
+        )
+
     async def async_pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
@@ -112,9 +130,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                 current_global_requests = 1
             # if above -> raise error
             if current_global_requests >= global_max_parallel_requests:
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
             # if below -> increment
             else:
                 await self.internal_usage_cache.async_increment_cache(
@@ -142,9 +158,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             ):
                 pass
             elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
             elif current is None:
                 new_val = {
                     "current_requests": 1,
@@ -169,9 +183,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                     request_count_api_key, new_val
                 )
             else:
-                raise HTTPException(
-                    status_code=429, detail="Max parallel request limit reached."
-                )
+                return self.raise_rate_limit_error()
 
         # check if REQUEST ALLOWED for user_id
         user_id = user_api_key_dict.user_id
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f2d59b15f..f05468d52 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -329,6 +329,7 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
     # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
+    headers = exc.headers
     return JSONResponse(
         status_code=(
             int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -341,6 +342,7 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
                 "code": exc.code,
             }
         },
+        headers=headers,
     )
 
 
@@ -3003,11 +3005,13 @@ async def chat_completion(
         router_model_names = llm_router.model_names if llm_router is not None else []
 
         if isinstance(e, HTTPException):
+            # print("e.headers={}".format(e.headers))
             raise ProxyException(
                 message=getattr(e, "detail", str(e)),
                 type=getattr(e, "type", "None"),
                 param=getattr(e, "param", "None"),
                 code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+                headers=getattr(e, "headers", {}),
             )
         error_msg = f"{str(e)}"
         raise ProxyException(
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index 06817fa87..df69686e1 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -1,9 +1,14 @@
 # What this tests?
 ## Unit Tests for the max parallel request limiter for the proxy
 
-import sys, os, asyncio, time, random
-from datetime import datetime
+import asyncio
+import os
+import random
+import sys
+import time
 import traceback
+from datetime import datetime
+
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -12,16 +17,18 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from datetime import datetime
+
 import pytest
+
 import litellm
 from litellm import Router
-from litellm.proxy.utils import ProxyLogging, hash_token
-from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler as MaxParallelRequestsHandler,
 )
-from datetime import datetime
+from litellm.proxy.utils import ProxyLogging, hash_token
 
 ## On Request received
 ## On Request success
@@ -139,6 +146,50 @@ async def test_pre_call_hook_rpm_limits():
         assert e.status_code == 429
 
 
+@pytest.mark.asyncio
+async def test_pre_call_hook_rpm_limits_retry_after():
+    """
+    Test if rate limit error, returns 'retry_after'
+    """
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key, max_parallel_requests=1, tpm_limit=9, rpm_limit=1
+    )
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler(
+        internal_usage_cache=local_cache
+    )
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj="",
+        start_time="",
+        end_time="",
+    )
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+        assert hasattr(e, "headers")
+        assert "retry-after" in e.headers
+
+
 @pytest.mark.asyncio
 async def test_pre_call_hook_team_rpm_limits():
     """
@@ -467,9 +518,10 @@ async def test_normal_router_call():
 
 @pytest.mark.asyncio
 async def test_normal_router_tpm_limit():
-    from litellm._logging import verbose_proxy_logger
     import logging
 
+    from litellm._logging import verbose_proxy_logger
+
     verbose_proxy_logger.setLevel(level=logging.DEBUG)
     model_list = [
         {