test(test_openai_endpoints.py): add concurrency testing for user defined rate limits on proxy

Krrish Dholakia 2024-04-12 18:56:13 -07:00
parent c03b0bbb24
commit ea1574c160
6 changed files with 68 additions and 28 deletions

View file

@@ -189,6 +189,9 @@ jobs:
             -p 4000:4000 \
             -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
             -e AZURE_API_KEY=$AZURE_API_KEY \
+            -e REDIS_HOST=$REDIS_HOST \
+            -e REDIS_PASSWORD=$REDIS_PASSWORD \
+            -e REDIS_PORT=$REDIS_PORT \
             -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
             -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
             -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \

View file

@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors

 class Router:
@@ -1295,6 +1295,8 @@ class Router:
                         min_timeout=self.retry_after,
                     )
                     await asyncio.sleep(timeout)
+                elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                    raise e  # don't wait to retry if deployment hits user-defined rate-limit
                 elif hasattr(original_exception, "status_code") and litellm._should_retry(
                     status_code=original_exception.status_code
                 ):
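Read as plain Python, the retry change above amounts to: if the exception message carries the user-defined rate-limit marker, re-raise immediately instead of sleeping and retrying. A minimal sketch of that control flow, assuming a hypothetical async callable `send_request` and a flat loop (the real router method also handles fallbacks, cooldowns, and computed backoff):

import asyncio
import litellm
from litellm.types.router import RouterErrors

async def call_with_retries(send_request, num_retries: int, retry_after: float):
    """Simplified sketch of the retry loop; not the router's actual implementation."""
    for attempt in range(num_retries + 1):
        try:
            return await send_request()
        except Exception as e:
            if attempt == num_retries:
                raise
            if RouterErrors.user_defined_ratelimit_error.value in str(e):
                # deployment hit a user-defined rpm limit: retrying won't help, so fail fast
                raise
            if hasattr(e, "status_code") and litellm._should_retry(status_code=e.status_code):
                # the real router computes a backoff via litellm._calculate_retry_after(...)
                await asyncio.sleep(retry_after)
            else:
                raise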
@@ -2376,6 +2378,7 @@ class Router:
         Filter out model in model group, if:
         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(
@@ -2400,7 +2403,7 @@ class Router:
                 rpm_key = f"{model}:rpm:{current_minute}"
                 model_group_cache = (
                     self.cache.get_cache(key=rpm_key, local_only=True) or {}
-                ) # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+                ) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
                 for idx, deployment in enumerate(_returned_deployments):
                     # see if we have the info for this model
                     try:
@@ -2436,6 +2439,7 @@ class Router:
                     self.cache.get_cache(key=model_id, local_only=True) or 0
                 )
                 ### get usage based cache ###
+                if isinstance(model_group_cache, dict):
                     model_group_cache[model_id] = model_group_cache.get(model_id, 0)
                 current_request = max(
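The pre-call check the two hunks above extend can be summarized as: build the per-minute rpm key for the model group, read the strategy's in-memory usage map, and drop any deployment whose counted requests already meet its configured rpm. A standalone sketch under the assumption that deployment dicts carry `model_info.id` and an optional `litellm_params.rpm` (the helper name and plain-dict cache are illustrative, not the real `_pre_call_checks`, which also performs context-window and TPM checks):

from datetime import datetime, timezone

def filter_over_rpm(deployments: list[dict], model_group: str, local_cache: dict) -> list[dict]:
    """Sketch: drop deployments whose current-minute request count meets their rpm limit."""
    current_minute = datetime.now(timezone.utc).strftime("%H-%M")
    rpm_key = f"{model_group}:rpm:{current_minute}"
    # in-memory usage map written by the lowest_latency / usage-based routing strategies
    model_group_cache = local_cache.get(rpm_key) or {}

    allowed = []
    for deployment in deployments:
        model_id = deployment.get("model_info", {}).get("id")
        rpm_limit = deployment.get("litellm_params", {}).get("rpm", float("inf"))
        current_requests = 0
        if isinstance(model_group_cache, dict):
            current_requests = model_group_cache.get(model_id, 0)
        if current_requests < rpm_limit:
            allowed.append(deployment)
    return allowed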

View file

@@ -14,6 +14,7 @@ from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
+from litellm.types.router import RouterErrors

 class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -58,7 +59,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             current_minute = dt.strftime("%H-%M")
             model_group = deployment.get("model_name", "")
             rpm_key = f"{model_group}:rpm:{current_minute}"
-            result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+            local_result = await self.router_cache.async_get_cache(
+                key=rpm_key, local_only=True
+            ) # check local result first

             deployment_rpm = None
             if deployment_rpm is None:
@@ -70,6 +73,26 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             if deployment_rpm is None:
                 deployment_rpm = float("inf")
+            if local_result is not None and local_result >= deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, local_result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            local_result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    ),
+                )
+            else:
+                # if local result below limit, check redis ## prevent unnecessary redis checks
+                result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
                 if result is not None and result > deployment_rpm:
                     raise litellm.RateLimitError(
                         message="Deployment over defined rpm limit={}. current usage={}".format(
@@ -79,8 +102,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                         model=deployment.get("litellm_params", {}).get("model"),
                         response=httpx.Response(
                             status_code=429,
-                            content="Deployment over defined rpm limit={}. current usage={}".format(
-                                deployment_rpm, result
+                            content="{} rpm limit={}. current usage={}".format(
+                                RouterErrors.user_defined_ratelimit_error.value,
+                                deployment_rpm,
+                                result,
                             ),
                             request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                         ),
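Taken together, the strategy-side changes above implement a two-step limit check: read the local in-memory counter first and fail fast if it already meets the deployment's rpm, and only otherwise increment the shared (Redis-backed) counter and re-check it. A rough sketch of that flow, assuming a cache object exposing the `async_get_cache` / `async_increment_cache` methods used in the diff (the real code raises `litellm.RateLimitError` with a 429 `httpx.Response`, not `RuntimeError`):

async def check_rpm_limit(router_cache, rpm_key: str, deployment_rpm: float):
    """Sketch of the local-first rpm check used by the usage-based-routing-v2 handler."""
    # 1) cheap in-memory read: if this instance is already over the limit,
    #    reject without touching Redis
    local_result = await router_cache.async_get_cache(key=rpm_key, local_only=True)
    if local_result is not None and local_result >= deployment_rpm:
        raise RuntimeError(
            f"Deployment over defined rpm limit={deployment_rpm}. current usage={local_result}"
        )

    # 2) otherwise increment the shared (Redis-backed) counter and re-check,
    #    so concurrent proxy instances see a consistent count
    result = await router_cache.async_increment_cache(key=rpm_key, value=1)
    if result is not None and result > deployment_rpm:
        raise RuntimeError(
            f"Deployment over defined rpm limit={deployment_rpm}. current usage={result}"
        )
    return result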

View file

@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum

 class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."

View file

@@ -67,12 +67,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

-# router_settings:
-# routing_strategy: usage-based-routing-v2
-# redis_host: os.environ/REDIS_HOST
-# redis_password: os.environ/REDIS_PASSWORD
-# redis_port: os.environ/REDIS_PORT
-# enable_pre_call_checks: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
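For reference, the now-enabled `router_settings` block maps onto the Python `Router` constructor; a minimal sketch under the assumption that a deployment's `rpm` in `litellm_params` is what drives the user-defined limit (model name, endpoint, and rpm value are placeholders):

import os
from litellm import Router

# Sketch: the proxy's router_settings expressed via the SDK.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.environ["AZURE_API_KEY"],
                "api_base": os.environ["AZURE_API_BASE"],
                "rpm": 100,  # user-defined per-deployment rate limit
            },
        },
    ],
    routing_strategy="usage-based-routing-v2",
    redis_host=os.environ["REDIS_HOST"],
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=int(os.environ["REDIS_PORT"]),
    enable_pre_call_checks=True,
)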

View file

@@ -194,7 +194,7 @@ async def test_chat_completion():
         await chat_completion(session=session, key=key_2)

-@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
 @pytest.mark.asyncio
 async def test_chat_completion_ratelimit():
     """