diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 4221b064e..8982e4e2b 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -71,7 +71,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
     ):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
-        max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
+        max_parallel_requests = user_api_key_dict.max_parallel_requests
+        if max_parallel_requests is None:
+            max_parallel_requests = sys.maxsize
         tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
         if tpm_limit is None:
             tpm_limit = sys.maxsize
@@ -105,6 +107,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             and rpm_limit == sys.maxsize
         ):
             pass
+        elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
+            raise HTTPException(
+                status_code=429, detail="Max parallel request limit reached."
+            )
         elif current is None:
             new_val = {
                 "current_requests": 1,
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 628f55852..ef54f29bd 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -8,6 +8,7 @@ import hashlib, uuid
 import warnings
 import importlib
 import warnings
+import backoff
 
 
 def showwarning(message, category, filename, lineno, file=None, line=None):
@@ -2298,6 +2299,11 @@ def parse_cache_control(cache_control):
     return cache_dict
 
 
+def on_backoff(details):
+    # The 'tries' key in the details dictionary contains the number of completed tries
+    verbose_proxy_logger.debug(f"Backing off... this was attempt #{details['tries']}")
+
+
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
@@ -2613,6 +2619,19 @@ async def completion(
     dependencies=[Depends(user_api_key_auth)],
     tags=["chat/completions"],
 )  # azure compatible endpoint
+@backoff.on_exception(
+    backoff.expo,
+    Exception,  # base exception to catch for the backoff
+    max_tries=litellm.num_retries or 3,  # maximum number of retries
+    max_time=litellm.request_timeout or 60,  # maximum total time to retry for
+    on_backoff=on_backoff,  # function to call each time we back off
+    giveup=lambda e: not (
+        isinstance(e, ProxyException)
+        and getattr(e, "message", None) is not None
+        and isinstance(e.message, str)
+        and "Max parallel request limit reached" in e.message
+    ),  # give up (stop retrying) unless this is the parallel request limit error
+)
 async def chat_completion(
     request: Request,
     fastapi_response: Response,
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index 64183f216..4b454f5bd 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -38,6 +38,8 @@ litellm_settings:
   drop_params: True
   max_budget: 100
   budget_duration: 30d
+  num_retries: 5
+  request_timeout: 600
 general_settings:
   master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
   proxy_budget_rescheduler_min_time: 10
diff --git a/tests/test_keys.py b/tests/test_keys.py
index 5a7b79e1c..413c24bc1 100644
--- a/tests/test_keys.py
+++ b/tests/test_keys.py
@@ -6,6 +6,7 @@ import asyncio, time
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
+from typing import Optional
 
 sys.path.insert(
     0, os.path.abspath("../")
@@ -19,6 +20,7 @@ async def generate_key(
     budget=None,
     budget_duration=None,
     models=["azure-models", "gpt-4", "dall-e-3"],
+    max_parallel_requests: Optional[int] = None,
 ):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@@ -28,6 +30,7 @@ async def generate_key(
         "duration": None,
         "max_budget": budget,
         "budget_duration": budget_duration,
+        "max_parallel_requests": max_parallel_requests,
     }
 
     print(f"data: {data}")
@@ -524,3 +527,29 @@ async def test_key_info_spend_values_sagemaker():
     rounded_key_info_spend = round(key_info["info"]["spend"], 8)
     assert rounded_key_info_spend > 0
     # assert rounded_response_cost == rounded_key_info_spend
+
+
+@pytest.mark.asyncio
+async def test_key_rate_limit():
+    """
+    Tests backoff/retry logic on parallel request error.
+    - Create key with max_parallel_requests=0
+    - make a request -> expect it to fail (429 parallel request limit)
+    - Create key with max_parallel_requests=1
+    - make a request
+    - expect it to succeed
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, i=0, max_parallel_requests=0)
+        new_key = key_gen["key"]
+        try:
+            await chat_completion(session=session, key=new_key)
+            pytest.fail("Expected this call to fail")
+        except Exception as e:
+            pass
+        key_gen = await generate_key(session=session, i=0, max_parallel_requests=1)
+        new_key = key_gen["key"]
+        try:
+            await chat_completion(session=session, key=new_key)
+        except Exception as e:
+            pytest.fail(f"Expected this call to work - {str(e)}")
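For context on the retry behavior added to proxy_server.py above: backoff.on_exception retries the wrapped call with exponential backoff until giveup(exception) returns True, so the lambda in the diff retries only errors whose message contains "Max parallel request limit reached" and re-raises everything else right away. Below is a minimal standalone sketch of that pattern, illustrative only and not part of the PR; LimitError, log_backoff, and flaky_call are hypothetical stand-ins for the proxy's ProxyException, logger, and endpoint.

    import backoff


    class LimitError(Exception):
        """Hypothetical stand-in for the proxy's 'Max parallel request limit reached' error."""


    def log_backoff(details):
        # backoff passes a details dict; 'tries' is the number of completed attempts
        print(f"Backing off... attempt #{details['tries']}")


    attempts = {"count": 0}


    @backoff.on_exception(
        backoff.expo,  # exponential wait between attempts
        Exception,  # catch every exception...
        max_tries=3,
        on_backoff=log_backoff,
        giveup=lambda e: not isinstance(e, LimitError),  # ...but retry only LimitError
    )
    def flaky_call():
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise LimitError("Max parallel request limit reached")
        return "ok"


    print(flaky_call())  # backs off twice, then prints "ok"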