test(test_openai_endpoints.py): add concurrency testing for user-defined rate limits on proxy

Krrish Dholakia 2024-04-12 18:56:13 -07:00
parent c03b0bbb24
commit ea1574c160
6 changed files with 68 additions and 28 deletions

View file

@@ -189,6 +189,9 @@ jobs:
-p 4000:4000 \
-e DATABASE_URL=$PROXY_DOCKER_DB_URL \
-e AZURE_API_KEY=$AZURE_API_KEY \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \
-e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
-e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \

View file

@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
class Router:
@@ -1295,6 +1295,8 @@ class Router:
min_timeout=self.retry_after,
)
await asyncio.sleep(timeout)
elif RouterErrors.user_defined_ratelimit_error.value in str(e):
raise e # don't wait to retry if deployment hits user-defined rate-limit
elif hasattr(original_exception, "status_code") and litellm._should_retry(
status_code=original_exception.status_code
):
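For reference, a minimal sketch of the retry behaviour this hunk introduces: errors carrying the RouterErrors marker are re-raised immediately instead of being retried after a backoff sleep. RouterErrors is the enum added to litellm.types.router in this commit; the retry helper itself is an illustrative stand-in, not the Router's real retry path.

    import asyncio
    from litellm.types.router import RouterErrors  # enum added in this commit

    async def call_with_retries(task, num_retries: int = 3, retry_after: float = 1.0):
        """Retry an async callable, but never sleep-and-retry a user-defined ratelimit error."""
        for attempt in range(num_retries + 1):
            try:
                return await task()
            except Exception as e:
                if RouterErrors.user_defined_ratelimit_error.value in str(e):
                    raise e  # don't wait to retry if deployment hits user-defined rate-limit
                if attempt == num_retries:
                    raise e
                await asyncio.sleep(retry_after)  # backoff before the next attempt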
@@ -2376,6 +2378,7 @@ class Router:
Filter out model in model group, if:
- model context window < message length
- filter models above rpm limits
- [TODO] function call and model doesn't support function calling
"""
verbose_router_logger.debug(
@@ -2400,7 +2403,7 @@ class Router:
rpm_key = f"{model}:rpm:{current_minute}"
model_group_cache = (
self.cache.get_cache(key=rpm_key, local_only=True) or {}
) # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
for idx, deployment in enumerate(_returned_deployments):
# see if we have the info for this model
try:
@@ -2436,23 +2439,24 @@ class Router:
self.cache.get_cache(key=model_id, local_only=True) or 0
)
### get usage based cache ###
model_group_cache[model_id] = model_group_cache.get(model_id, 0)
if isinstance(model_group_cache, dict):
model_group_cache[model_id] = model_group_cache.get(model_id, 0)
current_request = max(
current_request_cache_local, model_group_cache[model_id]
)
current_request = max(
current_request_cache_local, model_group_cache[model_id]
)
if (
isinstance(_litellm_params, dict)
and _litellm_params.get("rpm", None) is not None
):
if (
isinstance(_litellm_params["rpm"], int)
and _litellm_params["rpm"] <= current_request
isinstance(_litellm_params, dict)
and _litellm_params.get("rpm", None) is not None
):
invalid_model_indices.append(idx)
_rate_limit_error = True
continue
if (
isinstance(_litellm_params["rpm"], int)
and _litellm_params["rpm"] <= current_request
):
invalid_model_indices.append(idx)
_rate_limit_error = True
continue
if len(invalid_model_indices) == len(_returned_deployments):
"""

View file

@@ -14,6 +14,7 @@ from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger
from litellm.utils import print_verbose, get_utc_datetime
from litellm.types.router import RouterErrors
class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -58,7 +59,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
current_minute = dt.strftime("%H-%M")
model_group = deployment.get("model_name", "")
rpm_key = f"{model_group}:rpm:{current_minute}"
result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
local_result = await self.router_cache.async_get_cache(
key=rpm_key, local_only=True
) # check local result first
deployment_rpm = None
if deployment_rpm is None:
@@ -70,21 +73,43 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
if deployment_rpm is None:
deployment_rpm = float("inf")
if result is not None and result > deployment_rpm:
if local_result is not None and local_result >= deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
deployment_rpm, local_result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
local_result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
else:
# local result is below the limit -> increment the shared redis counter (checking local first avoids unnecessary redis calls)
result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return deployment
def log_success_event(self, kwargs, response_obj, start_time, end_time):
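The pattern above in one self-contained sketch: the in-memory counter is read first, and the shared (redis-backed) counter is only incremented when the local value is still under the limit. The callables and exception type here are illustrative stand-ins for litellm's DualCache and RateLimitError:

    class OverRpmLimit(Exception):
        pass

    async def pre_call_rpm_check(local_get, shared_increment, rpm_key: str, deployment_rpm: float):
        local_result = await local_get(rpm_key)  # in-memory only, no network round-trip
        if local_result is not None and local_result >= deployment_rpm:
            raise OverRpmLimit(f"rpm limit={deployment_rpm}. current usage={local_result}")
        # local counter is below the limit -> count this request on the shared counter
        result = await shared_increment(rpm_key, 1)
        if result is not None and result > deployment_rpm:
            raise OverRpmLimit(f"rpm limit={deployment_rpm}. current usage={result}")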

View file

@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
from pydantic import BaseModel, validator
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
import uuid
import uuid, enum
class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
class RouterErrors(enum.Enum):
"""
Enum for router specific errors with common codes
"""
user_defined_ratelimit_error = "Deployment over user-defined ratelimit."
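Small usage example: the enum value is embedded in the 429 response body by the rate-limit strategy above, so the router can recognise the failure by substring match and skip retries. The message string below is illustrative, mirroring the format used in lowest_tpm_rpm_v2:

    from litellm.types.router import RouterErrors

    err = Exception(
        "Deployment over user-defined ratelimit. rpm limit=10. current usage=12"
    )
    assert RouterErrors.user_defined_ratelimit_error.value in str(err)  # detected by substring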

View file

@@ -67,12 +67,12 @@ litellm_settings:
telemetry: False
context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
# router_settings:
# routing_strategy: usage-based-routing-v2
# redis_host: os.environ/REDIS_HOST
# redis_password: os.environ/REDIS_PASSWORD
# redis_port: os.environ/REDIS_PORT
# enable_pre_call_checks: true
router_settings:
routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: true
general_settings:
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
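For context, a rough Python equivalent of the router_settings enabled above, assuming the standard litellm Router constructor arguments; the model entry is illustrative and not taken from this config:

    import os
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 10},  # user-defined rpm limit
            }
        ],
        routing_strategy="usage-based-routing-v2",
        redis_host=os.environ["REDIS_HOST"],
        redis_password=os.environ["REDIS_PASSWORD"],
        redis_port=int(os.environ["REDIS_PORT"]),
        enable_pre_call_checks=True,
    )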

View file

@@ -194,7 +194,7 @@ async def test_chat_completion():
await chat_completion(session=session, key=key_2)
@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
@pytest.mark.asyncio
async def test_chat_completion_ratelimit():
"""