test(test_openai_endpoints.py): add concurrency testing for user defined rate limits on proxy

Krrish Dholakia 2024-04-12 18:56:13 -07:00
parent c03b0bbb24
commit ea1574c160
6 changed files with 68 additions and 28 deletions
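
The test file named in the commit title is not shown in this excerpt (only the router changes are). A minimal sketch of the kind of concurrency test the title describes, assuming a proxy at http://0.0.0.0:4000, a test key, and a deployment configured with rpm=1 — all placeholders, not taken from the diff:

# Hypothetical sketch of a concurrency test against the proxy's
# user-defined rate limits; endpoint, key, and model name are assumptions.
import asyncio
import openai


async def one_chat_completion(client: openai.AsyncOpenAI) -> bool:
    try:
        await client.chat.completions.create(
            model="fake-openai-endpoint",
            messages=[{"role": "user", "content": "hi"}],
        )
        return True
    except openai.RateLimitError:
        return False  # deployment hit its user-defined rpm limit


async def test_user_defined_rate_limit():
    client = openai.AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
    # fire requests concurrently; with rpm=1 on the deployment,
    # at most one should succeed within the minute
    results = await asyncio.gather(*[one_chat_completion(client) for _ in range(5)])
    assert results.count(True) <= 1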

litellm/router.py

@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
 
 
 class Router:
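
RouterErrors itself lives in litellm/types/router.py and is not part of this diff; the branch added below only needs a string-valued member it can match against exception text. A plausible shape, with the message string an assumption:

# Assumed sketch of the imported enum; the real member string is defined in
# litellm/types/router.py, not in this diff.
import enum


class RouterErrors(enum.Enum):
    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."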
@@ -1295,6 +1295,8 @@ class Router:
                         min_timeout=self.retry_after,
                     )
                     await asyncio.sleep(timeout)
+                elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                    raise e  # don't wait to retry if deployment hits user-defined rate-limit
                 elif hasattr(original_exception, "status_code") and litellm._should_retry(
                     status_code=original_exception.status_code
                 ):
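
The effect of the new branch: when the exception text carries the user-defined rate-limit marker, the router re-raises immediately instead of sleeping through a backoff. A condensed, hypothetical retry loop (not the real async_function_with_retries) showing the control flow, reusing the enum sketched above:

# Hypothetical condensed retry loop; illustrates the early exit only.
import asyncio


async def retry_with_backoff(func, num_retries: int, retry_after: float):
    for attempt in range(num_retries + 1):
        try:
            return await func()
        except Exception as e:
            if RouterErrors.user_defined_ratelimit_error.value in str(e):
                raise  # fail fast: let fallbacks/cooldowns react instead of blocking
            if attempt == num_retries:
                raise
            await asyncio.sleep(retry_after)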
@@ -2376,6 +2378,7 @@ class Router:
         Filter out model in model group, if:
 
         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(
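
For context, the rpm ceiling being filtered against is the user-defined one set per deployment in the router's model_list. A sketch of such a configuration (model names and key are placeholders), assuming the Router's enable_pre_call_checks flag gates this code path:

# Sketch: a deployment-level rpm limit as it might appear in the model_list.
# Names and key are placeholders; only the "rpm" field matters here.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "sk-...",
                "rpm": 10,  # user-defined requests-per-minute for this deployment
            },
        }
    ],
    enable_pre_call_checks=True,  # turn on the filtering shown in this diff
)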
@@ -2400,7 +2403,7 @@ class Router:
         rpm_key = f"{model}:rpm:{current_minute}"
         model_group_cache = (
             self.cache.get_cache(key=rpm_key, local_only=True) or {}
-        )  # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+        )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
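
The usage lookup is bucketed by minute, so counters reset implicitly when the key rolls over; the exact format of current_minute is not shown in this hunk. A worked example under that assumption, with the cached value's shape inferred from how model_group_cache is indexed below:

# Worked example of the per-minute rpm key; the "%H-%M" bucket format and the
# value shape are assumptions inferred from the surrounding code.
from datetime import datetime

model = "gpt-3.5-turbo"
current_minute = datetime.now().strftime("%H-%M")  # e.g. "18-56"
rpm_key = f"{model}:rpm:{current_minute}"          # "gpt-3.5-turbo:rpm:18-56"
# cached value maps each deployment's model_id to its requests this minute
example_value = {"model-id-1": 3, "model-id-2": 0}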
@@ -2436,23 +2439,24 @@ class Router:
                     self.cache.get_cache(key=model_id, local_only=True) or 0
                 )
                 ### get usage based cache ###
-                model_group_cache[model_id] = model_group_cache.get(model_id, 0)
+                if isinstance(model_group_cache, dict):
+                    model_group_cache[model_id] = model_group_cache.get(model_id, 0)
 
-                current_request = max(
-                    current_request_cache_local, model_group_cache[model_id]
-                )
+                    current_request = max(
+                        current_request_cache_local, model_group_cache[model_id]
+                    )
 
-                if (
-                    isinstance(_litellm_params, dict)
-                    and _litellm_params.get("rpm", None) is not None
-                ):
                     if (
-                        isinstance(_litellm_params["rpm"], int)
-                        and _litellm_params["rpm"] <= current_request
+                        isinstance(_litellm_params, dict)
+                        and _litellm_params.get("rpm", None) is not None
                     ):
-                        invalid_model_indices.append(idx)
-                        _rate_limit_error = True
-                        continue
+                        if (
+                            isinstance(_litellm_params["rpm"], int)
+                            and _litellm_params["rpm"] <= current_request
+                        ):
+                            invalid_model_indices.append(idx)
+                            _rate_limit_error = True
+                            continue
 
         if len(invalid_model_indices) == len(_returned_deployments):
             """