forked from phoenix/litellm-mirror
test(test_openai_endpoints.py): add concurrency testing for user defined rate limits on proxy
parent: c03b0bbb24
commit: ea1574c160
6 changed files with 68 additions and 28 deletions
@@ -189,6 +189,9 @@ jobs:
             -p 4000:4000 \
             -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
             -e AZURE_API_KEY=$AZURE_API_KEY \
+            -e REDIS_HOST=$REDIS_HOST \
+            -e REDIS_PASSWORD=$REDIS_PASSWORD \
+            -e REDIS_PORT=$REDIS_PORT \
             -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
             -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
             -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \

@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors


 class Router:

@@ -1295,6 +1295,8 @@ class Router:
                         min_timeout=self.retry_after,
                     )
                     await asyncio.sleep(timeout)
+                elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                    raise e  # don't wait to retry if deployment hits user-defined rate-limit
                 elif hasattr(original_exception, "status_code") and litellm._should_retry(
                     status_code=original_exception.status_code
                 ):

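The hunk above short-circuits the retry loop: when the exception text contains the new `RouterErrors` enum value, the router re-raises immediately instead of sleeping and retrying. Below is a minimal sketch of that pattern in isolation; `call_with_retries` and `make_request` are illustrative names, not part of the diff, and only the `RouterErrors` match mirrors the actual change.

```python
# Sketch of the short-circuit-on-user-defined-rate-limit pattern (not the Router's actual retry loop).
import asyncio

from litellm.types.router import RouterErrors  # enum added in this commit


async def call_with_retries(make_request, num_retries: int = 3, retry_after: float = 1.0):
    for _ in range(num_retries):
        try:
            return await make_request()
        except Exception as e:
            if RouterErrors.user_defined_ratelimit_error.value in str(e):
                # The deployment hit its user-defined rpm limit; retrying within
                # the same minute cannot succeed, so fail fast.
                raise e
            await asyncio.sleep(retry_after)  # other errors: back off, then retry
    return await make_request()
```
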
@@ -2376,6 +2378,7 @@ class Router:
         Filter out model in model group, if:

         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(

@@ -2400,7 +2403,7 @@ class Router:
                 rpm_key = f"{model}:rpm:{current_minute}"
                 model_group_cache = (
                     self.cache.get_cache(key=rpm_key, local_only=True) or {}
-                ) # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+                ) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
                 for idx, deployment in enumerate(_returned_deployments):
                     # see if we have the info for this model
                     try:

@@ -2436,23 +2439,24 @@ class Router:
                             self.cache.get_cache(key=model_id, local_only=True) or 0
                         )
                         ### get usage based cache ###
-                        model_group_cache[model_id] = model_group_cache.get(model_id, 0)
+                        if isinstance(model_group_cache, dict):
+                            model_group_cache[model_id] = model_group_cache.get(model_id, 0)

-                        current_request = max(
-                            current_request_cache_local, model_group_cache[model_id]
-                        )
+                            current_request = max(
+                                current_request_cache_local, model_group_cache[model_id]
+                            )

-                        if (
-                            isinstance(_litellm_params, dict)
-                            and _litellm_params.get("rpm", None) is not None
-                        ):
-                            if (
-                                isinstance(_litellm_params["rpm"], int)
-                                and _litellm_params["rpm"] <= current_request
-                            ):
-                                invalid_model_indices.append(idx)
-                                _rate_limit_error = True
-                                continue
+                            if (
+                                isinstance(_litellm_params, dict)
+                                and _litellm_params.get("rpm", None) is not None
+                            ):
+                                if (
+                                    isinstance(_litellm_params["rpm"], int)
+                                    and _litellm_params["rpm"] <= current_request
+                                ):
+                                    invalid_model_indices.append(idx)
+                                    _rate_limit_error = True
+                                    continue

             if len(invalid_model_indices) == len(_returned_deployments):
                 """

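For reference, the pre-call check this hunk hardens drops deployments whose user-defined rpm budget is already spent for the current minute. The sketch below reproduces only that filtering logic with made-up inputs; the function name, the `usage_by_model_id` / `model_group_usage` dictionaries, and the `model_info`/`litellm_params` field shapes are assumptions for illustration.

```python
# Simplified stand-alone version of the rpm filter in Router._pre_call_checks.
def filter_over_rpm(deployments, usage_by_model_id, model_group_usage):
    """Drop deployments whose user-defined rpm is already exhausted this minute."""
    kept = []
    for deployment in deployments:
        model_id = deployment["model_info"]["id"]
        litellm_params = deployment.get("litellm_params", {})
        # Take the higher of the local counter and the model-group counter,
        # mirroring the max(current_request_cache_local, model_group_cache[...]) above.
        current_request = max(
            usage_by_model_id.get(model_id, 0),
            model_group_usage.get(model_id, 0),
        )
        rpm = litellm_params.get("rpm")
        if isinstance(rpm, int) and rpm <= current_request:
            continue  # over its per-minute budget: filter it out
        kept.append(deployment)
    return kept
```
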
@@ -14,6 +14,7 @@ from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
+from litellm.types.router import RouterErrors


 class LowestTPMLoggingHandler_v2(CustomLogger):

@@ -58,7 +59,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         current_minute = dt.strftime("%H-%M")
         model_group = deployment.get("model_name", "")
         rpm_key = f"{model_group}:rpm:{current_minute}"
-        result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+        local_result = await self.router_cache.async_get_cache(
+            key=rpm_key, local_only=True
+        )  # check local result first

         deployment_rpm = None
         if deployment_rpm is None:

@@ -70,21 +73,43 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         if deployment_rpm is None:
             deployment_rpm = float("inf")

-        if result is not None and result > deployment_rpm:
+        if local_result is not None and local_result >= deployment_rpm:
             raise litellm.RateLimitError(
                 message="Deployment over defined rpm limit={}. current usage={}".format(
-                    deployment_rpm, result
+                    deployment_rpm, local_result
                 ),
                 llm_provider="",
                 model=deployment.get("litellm_params", {}).get("model"),
                 response=httpx.Response(
                     status_code=429,
-                    content="Deployment over defined rpm limit={}. current usage={}".format(
-                        deployment_rpm, result
+                    content="{} rpm limit={}. current usage={}".format(
+                        RouterErrors.user_defined_ratelimit_error.value,
+                        deployment_rpm,
+                        local_result,
                     ),
                     request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
                 ),
             )
+        else:
+            # if local result below limit, check redis ## prevent unnecessary redis checks
+            result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+            if result is not None and result > deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+                    ),
+                )
         return deployment

     def log_success_event(self, kwargs, response_obj, start_time, end_time):

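This hunk turns the rpm check into a two-step process: read the cheap in-memory counter first, and only increment the shared (Redis-backed) counter when the local count is still under the limit, which avoids a Redis round trip for deployments already known to be over budget. The following is a stand-alone sketch of that control flow under stated assumptions: `LocalCounter`, `RateLimited`, and `check_rpm` are hypothetical stand-ins for litellm's DualCache and RateLimitError, and only the ordering of the checks mirrors the diff.

```python
# Local-first rate-limit check: fail fast on the local counter, otherwise
# increment the shared counter and re-check.
import asyncio


class RateLimited(Exception):
    pass


class LocalCounter:
    """Toy async counter standing in for an in-memory or Redis-backed cache."""

    def __init__(self) -> None:
        self._counts: dict[str, int] = {}

    async def get(self, key: str) -> int:
        return self._counts.get(key, 0)

    async def increment(self, key: str) -> int:
        self._counts[key] = self._counts.get(key, 0) + 1
        return self._counts[key]


async def check_rpm(local: LocalCounter, shared: LocalCounter, key: str, rpm_limit: int) -> None:
    local_count = await local.get(key)
    if local_count >= rpm_limit:
        # Already over the limit according to the cheap local counter:
        # raise without touching the shared counter at all.
        raise RateLimited(f"over rpm limit={rpm_limit}, current usage={local_count}")
    # Under the limit locally: record the request in the shared counter and
    # re-check, since other workers may be sending traffic to the same deployment.
    shared_count = await shared.increment(key)
    if shared_count > rpm_limit:
        raise RateLimited(f"over rpm limit={rpm_limit}, current usage={shared_count}")


async def main() -> None:
    local, shared = LocalCounter(), LocalCounter()
    for _ in range(5):
        await check_rpm(local, shared, key="gpt-3.5-turbo:rpm:12-30", rpm_limit=3)


if __name__ == "__main__":
    asyncio.run(main())  # raises RateLimited on the 4th call
```
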
@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum


 class ModelConfig(BaseModel):

@@ -166,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."

@@ -67,12 +67,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys

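For readers using the SDK rather than the proxy, the sketch below is an assumed Python equivalent of the `router_settings` enabled above. The deployment entry (model name, `azure/chatgpt-v-2`, API base) and the `rpm: 100` value are placeholders, not taken from the diff; the Router keyword arguments shown are the documented ones, but treat the exact values as illustrative.

```python
# Assumed SDK-side equivalent of the proxy's router_settings block.
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",  # placeholder deployment
                "api_key": os.environ.get("AZURE_API_KEY"),
                "api_base": os.environ.get("AZURE_API_BASE"),
                "rpm": 100,  # user-defined rate limit enforced by usage-based-routing-v2
            },
        }
    ],
    routing_strategy="usage-based-routing-v2",
    redis_host=os.environ.get("REDIS_HOST"),
    redis_password=os.environ.get("REDIS_PASSWORD"),
    redis_port=int(os.environ.get("REDIS_PORT", "6379")),
    enable_pre_call_checks=True,
)
```
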
@@ -194,7 +194,7 @@ async def test_chat_completion():
         await chat_completion(session=session, key=key_2)


-@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
 @pytest.mark.asyncio
 async def test_chat_completion_ratelimit():
     """

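The body of `test_chat_completion_ratelimit` is not part of this hunk; only the skip marker is removed. As a rough idea of what a concurrency test against a user-defined rpm limit could look like, here is a hedged sketch: the proxy URL and port, the master key, the model name, the `chat_completion` helper, and the expectation of at least one 429 are all assumptions rather than the actual test.

```python
# Hypothetical sketch of a concurrency rate-limit test against the proxy.
import asyncio

import aiohttp
import pytest


async def chat_completion(session: aiohttp.ClientSession, key: str, model: str = "fake-openai-endpoint"):
    # Send one chat completion through the proxy and return (status, body text).
    async with session.post(
        "http://0.0.0.0:4000/chat/completions",
        headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
        json={"model": model, "messages": [{"role": "user", "content": "hi"}]},
    ) as resp:
        return resp.status, await resp.text()


@pytest.mark.asyncio
async def test_chat_completion_ratelimit_sketch():
    """Fire concurrent requests at a model with a low rpm; expect some 429s."""
    key = "sk-1234"  # placeholder master key from the config above
    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            *[chat_completion(session=session, key=key) for _ in range(20)],
            return_exceptions=True,
        )
    statuses = [r[0] for r in results if isinstance(r, tuple)]
    assert 429 in statuses  # at least one request should be rejected by the rpm limit
```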