diff --git a/.circleci/config.yml b/.circleci/config.yml
index aaad8df77..282202176 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -189,6 +189,9 @@ jobs:
           -p 4000:4000 \
           -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
           -e AZURE_API_KEY=$AZURE_API_KEY \
+          -e REDIS_HOST=$REDIS_HOST \
+          -e REDIS_PASSWORD=$REDIS_PASSWORD \
+          -e REDIS_PORT=$REDIS_PORT \
           -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
           -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
           -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
diff --git a/litellm/router.py b/litellm/router.py
index 9abf70956..bbc5b87dc 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
 
 
 class Router:
@@ -1295,6 +1295,8 @@ class Router:
                     min_timeout=self.retry_after,
                 )
                 await asyncio.sleep(timeout)
+            elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                raise e  # don't wait to retry if deployment hits user-defined rate-limit
             elif hasattr(original_exception, "status_code") and litellm._should_retry(
                 status_code=original_exception.status_code
             ):
@@ -2376,6 +2378,7 @@ class Router:
         Filter out model in model group, if:
 
         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(
@@ -2400,7 +2403,7 @@ class Router:
             rpm_key = f"{model}:rpm:{current_minute}"
             model_group_cache = (
                 self.cache.get_cache(key=rpm_key, local_only=True) or {}
-            )  # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+            )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
@@ -2436,23 +2439,24 @@ class Router:
                 self.cache.get_cache(key=model_id, local_only=True) or 0
             )
             ### get usage based cache ###
-            model_group_cache[model_id] = model_group_cache.get(model_id, 0)
+            if isinstance(model_group_cache, dict):
+                model_group_cache[model_id] = model_group_cache.get(model_id, 0)
 
-            current_request = max(
-                current_request_cache_local, model_group_cache[model_id]
-            )
+                current_request = max(
+                    current_request_cache_local, model_group_cache[model_id]
+                )
 
-            if (
-                isinstance(_litellm_params, dict)
-                and _litellm_params.get("rpm", None) is not None
-            ):
                 if (
-                    isinstance(_litellm_params["rpm"], int)
-                    and _litellm_params["rpm"] <= current_request
+                    isinstance(_litellm_params, dict)
+                    and _litellm_params.get("rpm", None) is not None
                 ):
-                    invalid_model_indices.append(idx)
-                    _rate_limit_error = True
-                    continue
+                    if (
+                        isinstance(_litellm_params["rpm"], int)
+                        and _litellm_params["rpm"] <= current_request
+                    ):
+                        invalid_model_indices.append(idx)
+                        _rate_limit_error = True
+                        continue
 
         if len(invalid_model_indices) == len(_returned_deployments):
             """
diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index 305f564aa..1c8738147 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -14,6 +14,7 @@ from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
+from litellm.types.router import RouterErrors
 
 
 class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -58,7 +59,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         current_minute = dt.strftime("%H-%M")
         model_group = deployment.get("model_name", "")
         rpm_key = f"{model_group}:rpm:{current_minute}"
-        result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+        local_result = await self.router_cache.async_get_cache(
+            key=rpm_key, local_only=True
+        )  # check local result first
 
         deployment_rpm = None
         if deployment_rpm is None:
@@ -70,21 +73,43 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         if deployment_rpm is None:
             deployment_rpm = float("inf")
 
-        if result is not None and result > deployment_rpm:
+        if local_result is not None and local_result >= deployment_rpm:
             raise litellm.RateLimitError(
                 message="Deployment over defined rpm limit={}. current usage={}".format(
-                    deployment_rpm, result
+                    deployment_rpm, local_result
                 ),
                 llm_provider="",
                 model=deployment.get("litellm_params", {}).get("model"),
                 response=httpx.Response(
                     status_code=429,
-                    content="Deployment over defined rpm limit={}. current usage={}".format(
-                        deployment_rpm, result
+                    content="{} rpm limit={}. current usage={}".format(
+                        RouterErrors.user_defined_ratelimit_error.value,
+                        deployment_rpm,
+                        local_result,
                     ),
                     request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                 ),
             )
+        else:
+            # if local result below limit, check redis ## prevent unnecessary redis checks
+            result = await self.router_cache.async_increment_cache(key=rpm_key, value=1)
+            if result is not None and result > deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    ),
+                )
         return deployment
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 8afd575f3..8509099c6 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum
 
 
 class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index 792e0c1f6..fa8c7fff7 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -67,12 +67,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
 
-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true
 
 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index 6fbdb7be5..28b7cde46 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -194,7 +194,7 @@ async def test_chat_completion():
         await chat_completion(session=session, key=key_2)
 
 
-@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
 @pytest.mark.asyncio
 async def test_chat_completion_ratelimit():
     """