mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(parallel_request_limiter.py): fix max parallel request limiter on retries
This commit is contained in:
parent
153ce0d085
commit
594ca947c8
4 changed files with 100 additions and 6 deletions
|
@ -79,6 +79,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
||||||
if max_parallel_requests is None:
|
if max_parallel_requests is None:
|
||||||
max_parallel_requests = sys.maxsize
|
max_parallel_requests = sys.maxsize
|
||||||
|
global_max_parallel_requests = data.get("metadata", {}).get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||||
if tpm_limit is None:
|
if tpm_limit is None:
|
||||||
tpm_limit = sys.maxsize
|
tpm_limit = sys.maxsize
|
||||||
|
@ -91,6 +94,24 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
# Setup values
|
# Setup values
|
||||||
# ------------
|
# ------------
|
||||||
|
|
||||||
|
if global_max_parallel_requests is not None:
|
||||||
|
# get value from cache
|
||||||
|
_key = "global_max_parallel_requests"
|
||||||
|
current_global_requests = await cache.async_get_cache(
|
||||||
|
key=_key, local_only=True
|
||||||
|
)
|
||||||
|
# check if below limit
|
||||||
|
if current_global_requests is None:
|
||||||
|
current_global_requests = 1
|
||||||
|
# if above -> raise error
|
||||||
|
if current_global_requests >= global_max_parallel_requests:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429, detail="Max parallel request limit reached."
|
||||||
|
)
|
||||||
|
# if below -> increment
|
||||||
|
else:
|
||||||
|
await cache.async_increment_cache(key=_key, value=1, local_only=True)
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
current_hour = datetime.now().strftime("%H")
|
current_hour = datetime.now().strftime("%H")
|
||||||
current_minute = datetime.now().strftime("%M")
|
current_minute = datetime.now().strftime("%M")
|
||||||
|
@ -207,6 +228,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
||||||
|
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||||
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
|
@ -222,6 +246,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
# Setup values
|
# Setup values
|
||||||
# ------------
|
# ------------
|
||||||
|
|
||||||
|
if global_max_parallel_requests is not None:
|
||||||
|
# get value from cache
|
||||||
|
_key = "global_max_parallel_requests"
|
||||||
|
# decrement
|
||||||
|
await self.user_api_key_cache.async_increment_cache(
|
||||||
|
key=_key, value=-1, local_only=True
|
||||||
|
)
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
current_hour = datetime.now().strftime("%H")
|
current_hour = datetime.now().strftime("%H")
|
||||||
current_minute = datetime.now().strftime("%M")
|
current_minute = datetime.now().strftime("%M")
|
||||||
|
@ -336,6 +368,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
||||||
|
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
user_api_key = (
|
user_api_key = (
|
||||||
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
|
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
|
||||||
)
|
)
|
||||||
|
@ -347,17 +382,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
return
|
return
|
||||||
|
|
||||||
## decrement call count if call failed
|
## decrement call count if call failed
|
||||||
if (
|
if "Max parallel request limit reached" in str(kwargs["exception"]):
|
||||||
hasattr(kwargs["exception"], "status_code")
|
|
||||||
and kwargs["exception"].status_code == 429
|
|
||||||
and "Max parallel request limit reached" in str(kwargs["exception"])
|
|
||||||
):
|
|
||||||
pass # ignore failed calls due to max limit being reached
|
pass # ignore failed calls due to max limit being reached
|
||||||
else:
|
else:
|
||||||
# ------------
|
# ------------
|
||||||
# Setup values
|
# Setup values
|
||||||
# ------------
|
# ------------
|
||||||
|
|
||||||
|
if global_max_parallel_requests is not None:
|
||||||
|
# get value from cache
|
||||||
|
_key = "global_max_parallel_requests"
|
||||||
|
current_global_requests = (
|
||||||
|
await self.user_api_key_cache.async_get_cache(
|
||||||
|
key=_key, local_only=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# decrement
|
||||||
|
await self.user_api_key_cache.async_increment_cache(
|
||||||
|
key=_key, value=-1, local_only=True
|
||||||
|
)
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
current_hour = datetime.now().strftime("%H")
|
current_hour = datetime.now().strftime("%H")
|
||||||
current_minute = datetime.now().strftime("%M")
|
current_minute = datetime.now().strftime("%M")
|
||||||
|
|
|
@ -2848,6 +2848,7 @@ class ProxyConfig:
|
||||||
"""
|
"""
|
||||||
Pull from DB, read general settings value
|
Pull from DB, read general settings value
|
||||||
"""
|
"""
|
||||||
|
global general_settings
|
||||||
if db_general_settings is None:
|
if db_general_settings is None:
|
||||||
return
|
return
|
||||||
_general_settings = dict(db_general_settings)
|
_general_settings = dict(db_general_settings)
|
||||||
|
@ -3690,6 +3691,9 @@ async def chat_completion(
|
||||||
data["metadata"]["user_api_key_alias"] = getattr(
|
data["metadata"]["user_api_key_alias"] = getattr(
|
||||||
user_api_key_dict, "key_alias", None
|
user_api_key_dict, "key_alias", None
|
||||||
)
|
)
|
||||||
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||||
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
|
data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
|
||||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||||
|
@ -3957,6 +3961,9 @@ async def completion(
|
||||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||||
user_api_key_dict, "team_id", None
|
user_api_key_dict, "team_id", None
|
||||||
)
|
)
|
||||||
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
data["metadata"]["user_api_key_team_alias"] = getattr(
|
data["metadata"]["user_api_key_team_alias"] = getattr(
|
||||||
user_api_key_dict, "team_alias", None
|
user_api_key_dict, "team_alias", None
|
||||||
)
|
)
|
||||||
|
@ -4151,6 +4158,9 @@ async def embeddings(
|
||||||
data["metadata"]["user_api_key_alias"] = getattr(
|
data["metadata"]["user_api_key_alias"] = getattr(
|
||||||
user_api_key_dict, "key_alias", None
|
user_api_key_dict, "key_alias", None
|
||||||
)
|
)
|
||||||
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||||
user_api_key_dict, "team_id", None
|
user_api_key_dict, "team_id", None
|
||||||
|
@ -4349,6 +4359,9 @@ async def image_generation(
|
||||||
data["metadata"]["user_api_key_alias"] = getattr(
|
data["metadata"]["user_api_key_alias"] = getattr(
|
||||||
user_api_key_dict, "key_alias", None
|
user_api_key_dict, "key_alias", None
|
||||||
)
|
)
|
||||||
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||||
user_api_key_dict, "team_id", None
|
user_api_key_dict, "team_id", None
|
||||||
|
@ -4529,6 +4542,9 @@ async def audio_transcriptions(
|
||||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||||
user_api_key_dict, "team_id", None
|
user_api_key_dict, "team_id", None
|
||||||
)
|
)
|
||||||
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
data["metadata"]["user_api_key_team_alias"] = getattr(
|
data["metadata"]["user_api_key_team_alias"] = getattr(
|
||||||
user_api_key_dict, "team_alias", None
|
user_api_key_dict, "team_alias", None
|
||||||
)
|
)
|
||||||
|
@ -4726,6 +4742,9 @@ async def moderations(
|
||||||
"authorization", None
|
"authorization", None
|
||||||
) # do not store the original `sk-..` api key in the db
|
) # do not store the original `sk-..` api key in the db
|
||||||
data["metadata"]["headers"] = _headers
|
data["metadata"]["headers"] = _headers
|
||||||
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
data["metadata"]["user_api_key_alias"] = getattr(
|
data["metadata"]["user_api_key_alias"] = getattr(
|
||||||
user_api_key_dict, "key_alias", None
|
user_api_key_dict, "key_alias", None
|
||||||
)
|
)
|
||||||
|
|
|
@ -28,6 +28,37 @@ from datetime import datetime
|
||||||
## On Request failure
|
## On Request failure
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_global_max_parallel_requests():
|
||||||
|
"""
|
||||||
|
Test if ParallelRequestHandler respects 'global_max_parallel_requests'
|
||||||
|
|
||||||
|
data["metadata"]["global_max_parallel_requests"]
|
||||||
|
"""
|
||||||
|
global_max_parallel_requests = 0
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
_api_key = hash_token("sk-12345")
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100)
|
||||||
|
local_cache = DualCache()
|
||||||
|
parallel_request_handler = MaxParallelRequestsHandler()
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
data={
|
||||||
|
"metadata": {
|
||||||
|
"global_max_parallel_requests": global_max_parallel_requests
|
||||||
|
}
|
||||||
|
},
|
||||||
|
call_type="",
|
||||||
|
)
|
||||||
|
pytest.fail("Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pre_call_hook():
|
async def test_pre_call_hook():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -2579,7 +2579,7 @@ class Logging:
|
||||||
response_obj=result,
|
response_obj=result,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
)
|
) # type: ignore
|
||||||
if callable(callback): # custom logger functions
|
if callable(callback): # custom logger functions
|
||||||
await customLogger.async_log_event(
|
await customLogger.async_log_event(
|
||||||
kwargs=self.model_call_details,
|
kwargs=self.model_call_details,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue