forked from phoenix/litellm-mirror
feat(proxy_server.py): retry if virtual key is rate limited
currently for chat completions
parent f95458dad8 · commit ad55f4dbb5
4 changed files with 57 additions and 1 deletion
@@ -71,7 +71,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
     ):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
-        max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
+        max_parallel_requests = user_api_key_dict.max_parallel_requests
+        if max_parallel_requests is None:
+            max_parallel_requests = sys.maxsize
         tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
         if tpm_limit is None:
             tpm_limit = sys.maxsize
@@ -105,6 +107,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             and rpm_limit == sys.maxsize
         ):
             pass
+        elif max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
+            raise HTTPException(
+                status_code=429, detail="Max parallel request limit reached."
+            )
         elif current is None:
             new_val = {
                 "current_requests": 1,
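Taken together, these two hunks normalize an unset limit to "unlimited" and treat an explicit limit of 0 as "always rate limited" (HTTP 429). A minimal, self-contained sketch of that pattern, with illustrative names that are not the hook's actual structure:

import sys

from fastapi import HTTPException


def normalize_limit(value):
    # None means "no limit configured" -> treat as effectively unlimited
    return sys.maxsize if value is None else value


def check_limits(max_parallel_requests, tpm_limit, rpm_limit):
    max_parallel_requests = normalize_limit(max_parallel_requests)
    tpm_limit = normalize_limit(tpm_limit)
    rpm_limit = normalize_limit(rpm_limit)

    # An explicit 0 blocks the key outright, so fail fast with a 429
    if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
        raise HTTPException(
            status_code=429, detail="Max parallel request limit reached."
        )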
@@ -8,6 +8,7 @@ import hashlib, uuid
 import warnings
 import importlib
 import warnings
+import backoff
 
 
 def showwarning(message, category, filename, lineno, file=None, line=None):
@@ -2298,6 +2299,11 @@ def parse_cache_control(cache_control):
     return cache_dict
 
 
+def on_backoff(details):
+    # The 'tries' key in the details dictionary contains the number of completed tries
+    verbose_proxy_logger.debug(f"Backing off... this was attempt #{details['tries']}")
+
+
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
@@ -2613,6 +2619,19 @@ async def completion(
     dependencies=[Depends(user_api_key_auth)],
     tags=["chat/completions"],
 )  # azure compatible endpoint
+@backoff.on_exception(
+    backoff.expo,
+    Exception,  # base exception to catch for the backoff
+    max_tries=litellm.num_retries or 3,  # maximum number of retries
+    max_time=litellm.request_timeout or 60,  # maximum total time to retry for
+    on_backoff=on_backoff,  # specifying the function to call on backoff
+    giveup=lambda e: not (
+        isinstance(e, ProxyException)
+        and getattr(e, "message", None) is not None
+        and isinstance(e.message, str)
+        and "Max parallel request limit reached" in e.message
+    ),  # give up (do not retry) unless this is the parallel-request-limit ProxyException
+)
 async def chat_completion(
     request: Request,
     fastapi_response: Response,
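The decorator retries chat_completion with exponential backoff, but only when the failure looks like the virtual key's parallel-request limit; for any other exception, giveup returns True and backoff re-raises immediately. A small self-contained sketch of the same mechanism (RateLimitedError and flaky_call are made-up names for illustration, not proxy code):

import asyncio

import backoff


class RateLimitedError(Exception):
    """Stand-in for the proxy's rate-limit error."""


def log_backoff(details):
    # backoff passes a dict with keys like 'tries', 'wait', and 'elapsed'
    print(f"retrying, attempt #{details['tries']}")


attempts = 0


@backoff.on_exception(
    backoff.expo,  # exponential wait between retries
    Exception,  # catch everything...
    max_tries=3,
    on_backoff=log_backoff,
    # ...but only keep retrying for rate-limit errors; give up on anything else
    giveup=lambda e: not isinstance(e, RateLimitedError),
)
async def flaky_call():
    global attempts
    attempts += 1
    if attempts < 3:
        raise RateLimitedError("Max parallel request limit reached")
    return "ok"


if __name__ == "__main__":
    print(asyncio.run(flaky_call()))  # succeeds on the third attempt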
@@ -38,6 +38,8 @@ litellm_settings:
   drop_params: True
   max_budget: 100
   budget_duration: 30d
+  num_retries: 5
+  request_timeout: 600
 general_settings:
   master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
   proxy_budget_rescheduler_min_time: 10
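These two settings feed the decorator above: litellm.num_retries and litellm.request_timeout take precedence, and the hard-coded 3 / 60 only apply when the attribute is unset (None), because of the `x or default` expressions. A quick illustration of that fallback:

num_retries = None  # not set in config
print(num_retries or 3)  # -> 3, the hard-coded fallback

num_retries = 5  # set via litellm_settings
print(num_retries or 3)  # -> 5, the configured value wins

Note that `or` treats 0 as falsy, so a configured num_retries of 0 would still fall back to 3; that edge case is not handled in this diff.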
@@ -6,6 +6,7 @@ import asyncio, time
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
+from typing import Optional
 
 sys.path.insert(
     0, os.path.abspath("../")
@@ -19,6 +20,7 @@ async def generate_key(
     budget=None,
     budget_duration=None,
     models=["azure-models", "gpt-4", "dall-e-3"],
+    max_parallel_requests: Optional[int] = None,
 ):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@@ -28,6 +30,7 @@ async def generate_key(
         "duration": None,
         "max_budget": budget,
         "budget_duration": budget_duration,
+        "max_parallel_requests": max_parallel_requests,
     }
 
     print(f"data: {data}")
@@ -524,3 +527,29 @@ async def test_key_info_spend_values_sagemaker():
     rounded_key_info_spend = round(key_info["info"]["spend"], 8)
     assert rounded_key_info_spend > 0
     # assert rounded_response_cost == rounded_key_info_spend
+
+
+@pytest.mark.asyncio
+async def test_key_rate_limit():
+    """
+    Tests backoff/retry logic on parallel request error.
+    - Create key with max parallel requests 0
+    - run a request -> it should fail
+    - Create key with max parallel requests 1
+    - run a request
+    - it should succeed
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, i=0, max_parallel_requests=0)
+        new_key = key_gen["key"]
+        try:
+            await chat_completion(session=session, key=new_key)
+            pytest.fail("Expected this call to fail")
+        except Exception as e:
+            pass
+        key_gen = await generate_key(session=session, i=0, max_parallel_requests=1)
+        new_key = key_gen["key"]
+        try:
+            await chat_completion(session=session, key=new_key)
+        except Exception as e:
+            pytest.fail(f"Expected this call to work - {str(e)}")
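The test relies on a chat_completion helper that is not part of this diff; it presumably POSTs to the proxy's /chat/completions route with the generated key and raises on a non-200 response. A hypothetical minimal version, only to make the test readable in isolation:

import aiohttp


async def chat_completion(session: aiohttp.ClientSession, key: str):
    # Hypothetical helper: the real one lives elsewhere in the test suite.
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
    data = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": "hello"}],
    }
    async with session.post(url, headers=headers, json=data) as response:
        if response.status != 200:
            raise Exception(f"Request failed with status {response.status}")
        return await response.json()

With a proxy listening on port 4000, the new test can be run on its own with pytest -k test_key_rate_limit.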