feat(proxy_server.py): retry if virtual key is rate limited

currently for chat completions
Krrish Dholakia 2024-03-05 19:00:03 -08:00
parent f95458dad8
commit ad55f4dbb5
4 changed files with 57 additions and 1 deletion
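
Note: the retry added here is built on the backoff library. The chat-completions route is wrapped with backoff.on_exception so a request rejected by the parallel-request limiter is retried with exponential backoff, while any other error surfaces immediately. A minimal, runnable sketch of that pattern (RateLimitedError and call_model are illustrative stand-ins, not names from this commit):

import asyncio
import backoff


class RateLimitedError(Exception):
    """Illustrative stand-in for the proxy's 429 rate-limit error."""
    def __init__(self, message):
        super().__init__(message)
        self.message = message


@backoff.on_exception(
    backoff.expo,   # exponential backoff between attempts
    Exception,      # catch broadly, then filter with giveup
    max_tries=3,
    max_time=60,
    giveup=lambda e: not (
        isinstance(e, RateLimitedError)
        and "Max parallel request limit reached" in e.message
    ),  # stop retrying immediately on anything that is not the rate-limit error
)
async def call_model(attempts: list) -> str:
    attempts.append(1)
    if len(attempts) < 3:
        raise RateLimitedError("Max parallel request limit reached.")
    return "ok"


if __name__ == "__main__":
    attempts: list = []
    print(asyncio.run(call_model(attempts)), "after", len(attempts), "attempts")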


@@ -71,7 +71,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
     ):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
-        max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
+        max_parallel_requests = user_api_key_dict.max_parallel_requests
+        if max_parallel_requests is None:
+            max_parallel_requests = sys.maxsize
         tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
         if tpm_limit is None:
             tpm_limit = sys.maxsize
@@ -105,6 +107,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             and rpm_limit == sys.maxsize
         ):
             pass
+        elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
+            raise HTTPException(
+                status_code=429, detail="Max parallel request limit reached."
+            )
         elif current is None:
             new_val = {
                 "current_requests": 1,


@@ -8,6 +8,7 @@ import hashlib, uuid
 import warnings
 import importlib
 import warnings
+import backoff


 def showwarning(message, category, filename, lineno, file=None, line=None):
@@ -2298,6 +2299,11 @@ def parse_cache_control(cache_control):
     return cache_dict


+def on_backoff(details):
+    # The 'tries' key in the details dictionary contains the number of completed tries
+    verbose_proxy_logger.debug(f"Backing off... this was attempt #{details['tries']}")
+
+
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
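
Note: backoff invokes the on_backoff handler with a details dict after each failed attempt; 'tries' is the number of completed attempts, and handlers also receive 'elapsed' and 'wait'. A small illustration with a throwaway function (flaky and log_backoff are illustrative names only):

import backoff


def log_backoff(details):
    # details carries the call's args/kwargs plus 'tries', 'elapsed' and 'wait'
    print(f"attempt #{details['tries']} failed, retrying in {details['wait']:.1f}s")


@backoff.on_exception(backoff.expo, ValueError, max_tries=3, on_backoff=log_backoff)
def flaky():
    raise ValueError("always fails")

Calling flaky() here logs two backoffs and then re-raises the ValueError after the third attempt.
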
@@ -2613,6 +2619,19 @@ async def completion(
     dependencies=[Depends(user_api_key_auth)],
     tags=["chat/completions"],
 ) # azure compatible endpoint
+@backoff.on_exception(
+    backoff.expo,
+    Exception, # base exception to catch for the backoff
+    max_tries=litellm.num_retries or 3, # maximum number of retries
+    max_time=litellm.request_timeout or 60, # maximum total time to retry for
+    on_backoff=on_backoff, # specifying the function to call on backoff
+    giveup=lambda e: not (
+        isinstance(e, ProxyException)
+        and getattr(e, "message", None) is not None
+        and isinstance(e.message, str)
+        and "Max parallel request limit reached" in e.message
+    ), # the result of the logical expression is on the second position
+)
 async def chat_completion(
     request: Request,
     fastapi_response: Response,
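
Note: the giveup predicate is what confines the retries to the rate-limit case; it returns True (stop retrying) for anything that is not a ProxyException whose message contains "Max parallel request limit reached". A standalone restatement of that predicate (this ProxyException is a minimal stand-in for illustration; the proxy's real class carries additional fields such as type, param and code):

class ProxyException(Exception):
    # minimal stand-in for illustration only
    def __init__(self, message: str):
        super().__init__(message)
        self.message = message


def should_give_up(e: Exception) -> bool:
    # retry only the parallel-request 429; give up on everything else
    return not (
        isinstance(e, ProxyException)
        and isinstance(getattr(e, "message", None), str)
        and "Max parallel request limit reached" in e.message
    )


assert should_give_up(ValueError("boom")) is True
assert should_give_up(ProxyException("Max parallel request limit reached.")) is False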


@@ -38,6 +38,8 @@ litellm_settings:
   drop_params: True
   max_budget: 100
   budget_duration: 30d
+  num_retries: 5
+  request_timeout: 600
 general_settings:
   master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
   proxy_budget_rescheduler_min_time: 10
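
Note: the two new litellm_settings keys are what the decorator above falls back to (max_tries=litellm.num_retries or 3, max_time=litellm.request_timeout or 60), so num_retries: 5 and request_timeout: 600 raise both ceilings. A rough sketch of how such settings could be applied from the YAML onto the litellm module (apply_litellm_settings is an illustrative helper, not the proxy's actual config loader):

import yaml

import litellm


def apply_litellm_settings(config_path: str) -> None:
    # copy every key under litellm_settings onto the litellm module,
    # e.g. num_retries: 5 becomes litellm.num_retries == 5
    with open(config_path) as f:
        config = yaml.safe_load(f) or {}
    for key, value in (config.get("litellm_settings") or {}).items():
        setattr(litellm, key, value)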


@@ -6,6 +6,7 @@ import asyncio, time
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
+from typing import Optional

 sys.path.insert(
     0, os.path.abspath("../")
@@ -19,6 +20,7 @@ async def generate_key(
     budget=None,
     budget_duration=None,
     models=["azure-models", "gpt-4", "dall-e-3"],
+    max_parallel_requests: Optional[int] = None,
 ):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@@ -28,6 +30,7 @@ async def generate_key(
         "duration": None,
         "max_budget": budget,
         "budget_duration": budget_duration,
+        "max_parallel_requests": max_parallel_requests,
     }
     print(f"data: {data}")
@@ -524,3 +527,29 @@ async def test_key_info_spend_values_sagemaker():
         rounded_key_info_spend = round(key_info["info"]["spend"], 8)
         assert rounded_key_info_spend > 0
         # assert rounded_response_cost == rounded_key_info_spend
+
+
+@pytest.mark.asyncio
+async def test_key_rate_limit():
+    """
+    Tests backoff/retry logic on parallel request error.
+    - Create key with max parallel requests 0
+    - run 2 requests -> both fail
+    - Create key with max parallel request 1
+    - run 2 requests
+    - both should succeed
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, i=0, max_parallel_requests=0)
+        new_key = key_gen["key"]
+        try:
+            await chat_completion(session=session, key=new_key)
+            pytest.fail(f"Expected this call to fail")
+        except Exception as e:
+            pass
+        key_gen = await generate_key(session=session, i=0, max_parallel_requests=1)
+        new_key = key_gen["key"]
+        try:
+            await chat_completion(session=session, key=new_key)
+        except Exception as e:
+            pytest.fail(f"Expected this call to work - {str(e)}")