forked from phoenix/litellm-mirror

test(test_openai_endpoints.py): add concurrency testing for user defined rate limits on proxy

parent c03b0bbb24
commit ea1574c160
6 changed files with 68 additions and 28 deletions
@@ -189,6 +189,9 @@ jobs:
       -p 4000:4000 \
       -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
       -e AZURE_API_KEY=$AZURE_API_KEY \
+      -e REDIS_HOST=$REDIS_HOST \
+      -e REDIS_PASSWORD=$REDIS_PASSWORD \
+      -e REDIS_PORT=$REDIS_PORT \
       -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
       -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
       -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
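Note: the three REDIS_* variables added above are what let the proxy container share per-minute rpm counters across workers once usage-based-routing-v2 is enabled. A minimal, hedged sketch of how a process inside the container might turn those variables into a Redis connection (hypothetical helper; litellm wires this up through its own cache layer, not with this exact code):

```python
# Hypothetical sketch only: consuming the REDIS_* variables passed via `docker run -e ...`.
import os

import redis  # assumes the redis-py client is available


def redis_client_from_env() -> redis.Redis:
    """Build a Redis client from the environment injected by the CI job above."""
    return redis.Redis(
        host=os.environ["REDIS_HOST"],
        port=int(os.environ.get("REDIS_PORT", 6379)),
        password=os.environ.get("REDIS_PASSWORD"),
    )


if __name__ == "__main__":
    client = redis_client_from_env()
    # A shared counter like this is what allows every proxy worker to see the same rpm usage.
    print(client.incr("gpt-3.5-turbo:rpm:12-30"))
```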
@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors


 class Router:
@@ -1295,6 +1295,8 @@ class Router:
                     min_timeout=self.retry_after,
                 )
                 await asyncio.sleep(timeout)
+            elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                raise e  # don't wait to retry if deployment hits user-defined rate-limit
             elif hasattr(original_exception, "status_code") and litellm._should_retry(
                 status_code=original_exception.status_code
             ):
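Note: the two added lines above make the retry loop surface user-defined rate-limit errors immediately instead of sleeping and retrying. A simplified, hedged sketch of that control flow (standalone function with illustrative names, not the Router's actual retry method):

```python
# Simplified sketch of the retry behaviour added above (illustrative, not the real Router code).
import asyncio
from enum import Enum


class RouterErrors(Enum):
    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."


async def call_with_retries(make_request, num_retries: int = 3, retry_after: float = 1.0):
    for attempt in range(num_retries + 1):
        try:
            return await make_request()
        except Exception as e:
            if RouterErrors.user_defined_ratelimit_error.value in str(e):
                # Waiting will not help: the limit was configured by the user, so fail fast.
                raise
            if attempt == num_retries:
                raise
            await asyncio.sleep(retry_after)  # flat backoff stands in for the real timeout calculation
```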
@@ -2376,6 +2378,7 @@ class Router:
         Filter out model in model group, if:

         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(
@@ -2400,7 +2403,7 @@ class Router:
         rpm_key = f"{model}:rpm:{current_minute}"
         model_group_cache = (
             self.cache.get_cache(key=rpm_key, local_only=True) or {}
-        )  # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+        )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
@@ -2436,23 +2439,24 @@ class Router:
                     self.cache.get_cache(key=model_id, local_only=True) or 0
                 )
                 ### get usage based cache ###
-                model_group_cache[model_id] = model_group_cache.get(model_id, 0)
+                if isinstance(model_group_cache, dict):
+                    model_group_cache[model_id] = model_group_cache.get(model_id, 0)

-                current_request = max(
-                    current_request_cache_local, model_group_cache[model_id]
-                )
+                    current_request = max(
+                        current_request_cache_local, model_group_cache[model_id]
+                    )

-                if (
-                    isinstance(_litellm_params, dict)
-                    and _litellm_params.get("rpm", None) is not None
-                ):
-                    if (
-                        isinstance(_litellm_params["rpm"], int)
-                        and _litellm_params["rpm"] <= current_request
-                    ):
-                        invalid_model_indices.append(idx)
-                        _rate_limit_error = True
-                        continue
+                    if (
+                        isinstance(_litellm_params, dict)
+                        and _litellm_params.get("rpm", None) is not None
+                    ):
+                        if (
+                            isinstance(_litellm_params["rpm"], int)
+                            and _litellm_params["rpm"] <= current_request
+                        ):
+                            invalid_model_indices.append(idx)
+                            _rate_limit_error = True
+                            continue

         if len(invalid_model_indices) == len(_returned_deployments):
             """
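Note: the hunks above extend the router's pre-call checks so deployments that have already hit their user-defined rpm are dropped from the candidate list before a request is routed; the counters are read from a `{model}:rpm:{current_minute}` key. A hedged sketch of that filtering idea as a standalone helper (the function itself is illustrative, not Router's actual method):

```python
# Illustrative helper mirroring the rpm pre-call filter above (not Router's real method).
from typing import Dict, List


def filter_deployments_over_rpm(
    deployments: List[dict],
    local_cache: Dict[str, int],
    model_group_cache: Dict[str, int],
) -> List[dict]:
    """Drop deployments whose observed requests this minute already meet their rpm limit."""
    healthy = []
    for deployment in deployments:
        model_id = deployment.get("model_info", {}).get("id", "")
        litellm_params = deployment.get("litellm_params", {})
        # Take the larger of the local counter and the shared model-group counter.
        current_request = max(
            local_cache.get(model_id, 0), model_group_cache.get(model_id, 0)
        )
        rpm_limit = litellm_params.get("rpm")
        if isinstance(rpm_limit, int) and rpm_limit <= current_request:
            continue  # already at its user-defined rpm -> exclude from routing
        healthy.append(deployment)
    return healthy
```

The surrounding code also tracks the case where every deployment gets filtered out (the `if len(invalid_model_indices) == len(_returned_deployments)` check above), which is where an all-deployments-rate-limited situation is handled.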
@@ -14,6 +14,7 @@ from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
+from litellm.types.router import RouterErrors


 class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -58,7 +59,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         current_minute = dt.strftime("%H-%M")
         model_group = deployment.get("model_name", "")
         rpm_key = f"{model_group}:rpm:{current_minute}"
-        result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+        local_result = await self.router_cache.async_get_cache(
+            key=rpm_key, local_only=True
+        )  # check local result first

         deployment_rpm = None
         if deployment_rpm is None:
@@ -70,21 +73,43 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         if deployment_rpm is None:
             deployment_rpm = float("inf")

-        if result is not None and result > deployment_rpm:
-            raise litellm.RateLimitError(
-                message="Deployment over defined rpm limit={}. current usage={}".format(
-                    deployment_rpm, result
-                ),
-                llm_provider="",
-                model=deployment.get("litellm_params", {}).get("model"),
-                response=httpx.Response(
-                    status_code=429,
-                    content="Deployment over defined rpm limit={}. current usage={}".format(
-                        deployment_rpm, result
-                    ),
-                    request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
-                ),
-            )
+        if local_result is not None and local_result >= deployment_rpm:
+            raise litellm.RateLimitError(
+                message="Deployment over defined rpm limit={}. current usage={}".format(
+                    deployment_rpm, local_result
+                ),
+                llm_provider="",
+                model=deployment.get("litellm_params", {}).get("model"),
+                response=httpx.Response(
+                    status_code=429,
+                    content="{} rpm limit={}. current usage={}".format(
+                        RouterErrors.user_defined_ratelimit_error.value,
+                        deployment_rpm,
+                        local_result,
+                    ),
+                    request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+                ),
+            )
+        else:
+            # if local result below limit, check redis ## prevent unnecessary redis checks
+            result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+            if result is not None and result > deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+                    ),
+                )
         return deployment

     def log_success_event(self, kwargs, response_obj, start_time, end_time):
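Note: the pre-call check above now consults the in-memory counter first and only increments the shared (Redis-backed) counter when the local value is still under the limit, so a deployment already over its user-defined rpm is rejected without a Redis round-trip. A hedged sketch of that two-stage pattern with a toy in-process stand-in for the dual cache (names are illustrative, not litellm's DualCache API):

```python
# Toy two-stage check mirroring the local-first pattern above (illustrative, not litellm's DualCache).
from typing import Dict, Optional


class TwoLevelCounter:
    """In-memory counter plus a 'shared' dict standing in for the Redis-backed store."""

    def __init__(self) -> None:
        self.local: Dict[str, int] = {}
        self.shared: Dict[str, int] = {}  # pretend this lives in Redis

    def get_local(self, key: str) -> Optional[int]:
        return self.local.get(key)

    def increment_shared(self, key: str, value: int = 1) -> int:
        self.shared[key] = self.shared.get(key, 0) + value
        self.local[key] = self.shared[key]  # keep the local view in sync
        return self.shared[key]


def check_rpm(cache: TwoLevelCounter, rpm_key: str, deployment_rpm: float) -> int:
    local = cache.get_local(rpm_key)
    if local is not None and local >= deployment_rpm:
        # Fail fast on the local counter; no shared-store round-trip needed.
        raise RuntimeError(f"Deployment over defined rpm limit={deployment_rpm}. current usage={local}")
    result = cache.increment_shared(rpm_key)
    if result > deployment_rpm:
        raise RuntimeError(f"Deployment over defined rpm limit={deployment_rpm}. current usage={result}")
    return result
```

Checking locally first trades a slightly staler view of usage for fewer shared-store calls on the hot rejection path.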
@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum


 class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."
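Note: the new RouterErrors enum gives the router and the routing strategies a shared marker string for user-defined rate-limit failures, which is what the retry logic matches on. A small, hedged usage sketch (the sample exception text is invented for illustration):

```python
# Illustrative use of the RouterErrors enum added above; the sample exception text is invented.
import enum


class RouterErrors(enum.Enum):
    """
    Enum for router specific errors with common codes
    """

    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."


def should_retry(exc: Exception) -> bool:
    # Matching on the enum's value keeps the check stable even though the real error
    # message also embeds the limit and current usage after the marker string.
    return RouterErrors.user_defined_ratelimit_error.value not in str(exc)


if __name__ == "__main__":
    err = RuntimeError("Deployment over user-defined ratelimit. rpm limit=10. current usage=12")
    print(should_retry(err))  # False -> surface the error instead of retrying
```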
@@ -67,12 +67,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
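Note: un-commenting router_settings above switches the proxy to usage-based-routing-v2 backed by Redis, so rpm counters are shared across proxy workers. A hedged sketch of the equivalent programmatic setup, assuming the same environment variables; the single deployment and its rpm value are illustrative, not taken from the config:

```python
# Hedged sketch: building a litellm Router with the same settings as the YAML above.
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 10},  # user-defined rpm limit
        }
    ],
    routing_strategy="usage-based-routing-v2",
    redis_host=os.environ.get("REDIS_HOST"),
    redis_password=os.environ.get("REDIS_PASSWORD"),
    redis_port=int(os.environ.get("REDIS_PORT", "6379")),
    enable_pre_call_checks=True,
)
```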
@@ -194,7 +194,7 @@ async def test_chat_completion():
         await chat_completion(session=session, key=key_2)


-@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
 @pytest.mark.asyncio
 async def test_chat_completion_ratelimit():
     """
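Note: the re-enabled test_chat_completion_ratelimit sends concurrent requests against a low-rpm key and expects a mix of successes and 429-style rejections. The diff only shows the decorator change, so the following is a hedged sketch of what such a concurrency test typically looks like; the endpoint and key mirror the proxy config shown earlier (port 4000, master_key sk-1234), while the payload, model name, and expected counts are illustrative:

```python
# Hedged sketch of a concurrency rate-limit test in the spirit of test_chat_completion_ratelimit.
import asyncio

import aiohttp
import pytest

PROXY_URL = "http://0.0.0.0:4000/chat/completions"
HEADERS = {"Authorization": "Bearer sk-1234"}


async def one_request(session: aiohttp.ClientSession) -> int:
    payload = {
        "model": "fake-openai-endpoint",  # illustrative model name
        "messages": [{"role": "user", "content": "hello"}],
    }
    async with session.post(PROXY_URL, headers=HEADERS, json=payload) as resp:
        return resp.status


@pytest.mark.asyncio
async def test_chat_completion_ratelimit_sketch():
    """Fire several requests at once; expect a mix of 200s and 429s once the rpm limit is hit."""
    async with aiohttp.ClientSession() as session:
        statuses = await asyncio.gather(*[one_request(session) for _ in range(10)])
    assert 200 in statuses, "at least one request should get through"
    assert 429 in statuses, "requests past the user-defined rpm limit should be rejected"
```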