diff --git a/.circleci/config.yml b/.circleci/config.yml
index aaad8df77..282202176 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -189,6 +189,9 @@ jobs:
             -p 4000:4000 \
             -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
             -e AZURE_API_KEY=$AZURE_API_KEY \
+            -e REDIS_HOST=$REDIS_HOST \
+            -e REDIS_PASSWORD=$REDIS_PASSWORD \
+            -e REDIS_PORT=$REDIS_PORT \
             -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
             -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
             -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
diff --git a/litellm/caching.py b/litellm/caching.py
index 2401d9708..cdb98d790 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -98,11 +98,12 @@ class InMemoryCache(BaseCache):
             return_val.append(val)
         return return_val
 
-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         # get the value
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+        return value
 
     def flush_cache(self):
         self.cache_dict.clear()
@@ -266,11 +267,12 @@ class RedisCache(BaseCache):
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()
 
-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         _redis_client = self.init_async_client()
         try:
             async with _redis_client as redis_client:
-                await redis_client.incr(name=key, amount=value)
+                result = await redis_client.incr(name=key, amount=value)
+                return result
         except Exception as e:
             verbose_logger.error(
                 "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
                 str(e),
                 value,
             )
             traceback.print_exc()
+            raise e
 
     async def flush_cache_buffer(self):
         print_verbose(
@@ -1076,21 +1079,29 @@ class DualCache(BaseCache):
 
     async def async_increment_cache(
         self, key, value: int, local_only: bool = False, **kwargs
-    ):
+    ) -> int:
         """
         Key - the key in cache
 
         Value - int - the value you want to increment by
+
+        Returns - int - the incremented value
         """
         try:
+            result: int = value
             if self.in_memory_cache is not None:
-                await self.in_memory_cache.async_increment(key, value, **kwargs)
+                result = await self.in_memory_cache.async_increment(
+                    key, value, **kwargs
+                )
 
             if self.redis_cache is not None and local_only == False:
-                await self.redis_cache.async_increment(key, value, **kwargs)
+                result = await self.redis_cache.async_increment(key, value, **kwargs)
+
+            return result
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()
+            raise e
 
     def flush_cache(self):
         if self.in_memory_cache is not None:
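Note on the caching changes above: `async_increment` / `async_increment_cache` now return the post-increment counter, which is what lets the rpm check further down read the running count without a second cache round trip. A minimal sketch of the resulting behavior, assuming an in-memory-only `DualCache` and a hypothetical rpm key (illustrative only, not code from this PR):

```python
import asyncio

from litellm.caching import DualCache


async def main():
    cache = DualCache()  # in-memory only in this sketch; no Redis configured
    rpm_key = "my-deployment:rpm:14-05"  # hypothetical key, mirrors the {id}:rpm:{current_minute} pattern
    for _ in range(3):
        # after this PR, the incremented value is returned to the caller
        current = await cache.async_increment_cache(key=rpm_key, value=1)
        print("requests this minute:", current)


asyncio.run(main())
```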
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index cf318b581..6026fd834 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1836,6 +1836,9 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)
 
 
+semaphore = asyncio.Semaphore(1)
+
+
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
@@ -2425,8 +2428,7 @@ class ProxyConfig:
         for k, v in router_settings.items():
             if k in available_args:
                 router_params[k] = v
-
-        router = litellm.Router(**router_params)  # type:ignore
+        router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
         return router, model_list, general_settings
 
     async def add_deployment(
@@ -3421,6 +3423,7 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     try:
+        # async with llm_router.sem
         data = {}
         body = await request.body()
         body_str = body.decode()
@@ -3525,7 +3528,9 @@ async def chat_completion(
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
-                data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type="completion",
             )
         )
 
diff --git a/litellm/router.py b/litellm/router.py
index 8bb943af9..bbc5b87dc 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
 
 
 class Router:
@@ -78,6 +78,7 @@ class Router:
             "latency-based-routing",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
+        semaphore: Optional[asyncio.Semaphore] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -143,6 +144,8 @@ class Router:
        router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
        ```
        """
+        if semaphore:
+            self.semaphore = semaphore
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
@@ -409,11 +412,18 @@ class Router:
             raise e
 
     async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+        """
+        - Get an available deployment
+        - call it with a semaphore over the call
+        - semaphore specific to its rpm
+        - in the semaphore, make a check against its local rpm before running
+        """
         model_name = None
         try:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )
+
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
@@ -443,6 +453,7 @@ class Router:
             potential_model_client = self._get_client(
                 deployment=deployment, kwargs=kwargs, client_type="async"
             )
+
             # check if provided keys == client keys #
             dynamic_api_key = kwargs.get("api_key", None)
             if (
@@ -465,7 +476,7 @@ class Router:
                 )  # this uses default_litellm_params when nothing is set
             )
 
-            response = await litellm.acompletion(
+            _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
@@ -475,6 +486,25 @@ class Router:
                     **kwargs,
                 }
             )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if (
+                rpm_semaphore is not None
+                and isinstance(rpm_semaphore, asyncio.Semaphore)
+                and self.routing_strategy == "usage-based-routing-v2"
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    """
+                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
+                    response = await _response
+            else:
+                response = await _response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
@@ -1265,6 +1295,8 @@ class Router:
                         min_timeout=self.retry_after,
                     )
                     await asyncio.sleep(timeout)
+                elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                    raise e  # don't wait to retry if deployment hits user-defined rate-limit
                 elif hasattr(original_exception, "status_code") and litellm._should_retry(
                     status_code=original_exception.status_code
                 ):
@@ -1680,12 +1712,26 @@ class Router:
 
     def set_client(self, model: dict):
         """
-        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/2994
        """
         client_ttl = self.client_ttl
         litellm_params = model.get("litellm_params", {})
         model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
+        # ### IF RPM SET - initialize a semaphore ###
+        rpm = litellm_params.get("rpm", None)
+        if rpm:
+            semaphore = asyncio.Semaphore(rpm)
+            cache_key = f"{model_id}_rpm_client"
+            self.cache.set_cache(
+                key=cache_key,
+                value=semaphore,
+                local_only=True,
+            )
+
+            # print("STORES SEMAPHORE IN CACHE")
+
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
@@ -2275,7 +2321,11 @@ class Router:
             The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "async":
+        if client_type == "rpm_client":
+            cache_key = "{}_rpm_client".format(model_id)
+            client = self.cache.get_cache(key=cache_key, local_only=True)
+            return client
+        elif client_type == "async":
             if kwargs.get("stream") == True:
                 cache_key = f"{model_id}_stream_async_client"
                 client = self.cache.get_cache(key=cache_key, local_only=True)
@@ -2328,6 +2378,7 @@ class Router:
         Filter out model in model group, if:
 
         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(
@@ -2352,7 +2403,7 @@ class Router:
         rpm_key = f"{model}:rpm:{current_minute}"
         model_group_cache = (
             self.cache.get_cache(key=rpm_key, local_only=True) or {}
-        )  # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+        )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
@@ -2388,23 +2439,24 @@ class Router:
                    self.cache.get_cache(key=model_id, local_only=True) or 0
                )
                ### get usage based cache ###
-                model_group_cache[model_id] = model_group_cache.get(model_id, 0)
+                if isinstance(model_group_cache, dict):
+                    model_group_cache[model_id] = model_group_cache.get(model_id, 0)
 
-                current_request = max(
-                    current_request_cache_local, model_group_cache[model_id]
-                )
+                    current_request = max(
+                        current_request_cache_local, model_group_cache[model_id]
+                    )
 
-                if (
-                    isinstance(_litellm_params, dict)
-                    and _litellm_params.get("rpm", None) is not None
-                ):
                     if (
-                        isinstance(_litellm_params["rpm"], int)
-                        and _litellm_params["rpm"] <= current_request
+                        isinstance(_litellm_params, dict)
+                        and _litellm_params.get("rpm", None) is not None
                     ):
-                        invalid_model_indices.append(idx)
-                        _rate_limit_error = True
-                        continue
+                        if (
+                            isinstance(_litellm_params["rpm"], int)
+                            and _litellm_params["rpm"] <= current_request
+                        ):
+                            invalid_model_indices.append(idx)
+                            _rate_limit_error = True
+                            continue
 
         if len(invalid_model_indices) == len(_returned_deployments):
             """
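The `_acompletion` change above builds the `litellm.acompletion` coroutine first, then awaits it inside the deployment's rpm-sized semaphore after the pre-call rpm check has passed. A standalone sketch of that pattern, with hypothetical `make_call` / `pre_call_check` stand-ins rather than the real router internals:

```python
import asyncio


async def make_call(i: int) -> str:
    await asyncio.sleep(0.01)  # stand-in for the actual LLM request
    return f"response {i}"


async def pre_call_check(counter: dict, rpm_limit: int) -> None:
    counter["count"] += 1
    if counter["count"] > rpm_limit:
        raise RuntimeError("Deployment over user-defined ratelimit.")


async def guarded_call(i: int, sem: asyncio.Semaphore, counter: dict, rpm_limit: int) -> str:
    pending = make_call(i)  # coroutine created, not awaited yet (mirrors `_response = litellm.acompletion(...)`)
    async with sem:
        try:
            await pre_call_check(counter, rpm_limit)
        except Exception:
            pending.close()  # in this sketch, drop the un-awaited coroutine cleanly
            raise
        return await pending


async def main():
    rpm_limit = 3
    sem = asyncio.Semaphore(rpm_limit)
    counter = {"count": 0}
    results = await asyncio.gather(
        *[guarded_call(i, sem, counter, rpm_limit) for i in range(5)],
        return_exceptions=True,
    )
    print(results)  # calls past the limit surface as RuntimeError entries


asyncio.run(main())
```

The semaphore bounds how many requests can be inside the check-then-call section at once, so the rpm counter cannot be overshot by a burst of concurrent calls (the concurrency issue referenced in #2994).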
diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index c5598c11e..3babe0345 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -7,12 +7,14 @@ import datetime as datetime_og
 from datetime import datetime
 
 dotenv.load_dotenv()  # Loading env variables using dotenv
-import traceback, asyncio
+import traceback, asyncio, httpx
+import litellm
 from litellm import token_counter
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
+from litellm.types.router import RouterErrors
 
 
 class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -37,6 +39,86 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list
 
+    async def pre_call_rpm_check(self, deployment: dict) -> dict:
+        """
+        Pre-call check + update model rpm
+        - Used inside semaphore
+        - raise rate limit error if deployment over limit
+
+        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994
+
+        Returns - deployment
+
+        Raises - RateLimitError if deployment over defined RPM limit
+        """
+        try:
+
+            # ------------
+            # Setup values
+            # ------------
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+            model_group = deployment.get("model_name", "")
+            rpm_key = f"{model_group}:rpm:{current_minute}"
+            local_result = await self.router_cache.async_get_cache(
+                key=rpm_key, local_only=True
+            )  # check local result first
+
+            deployment_rpm = None
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("model_info", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = float("inf")
+
+            if local_result is not None and local_result >= deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, local_result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            local_result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    ),
+                )
+            else:
+                # if local result below limit, check redis ## prevent unnecessary redis checks
+                result = await self.router_cache.async_increment_cache(
+                    key=rpm_key, value=1
+                )
+                if result is not None and result > deployment_rpm:
+                    raise litellm.RateLimitError(
+                        message="Deployment over defined rpm limit={}. current usage={}".format(
+                            deployment_rpm, result
+                        ),
+                        llm_provider="",
+                        model=deployment.get("litellm_params", {}).get("model"),
+                        response=httpx.Response(
+                            status_code=429,
+                            content="{} rpm limit={}. current usage={}".format(
+                                RouterErrors.user_defined_ratelimit_error.value,
+                                deployment_rpm,
+                                result,
+                            ),
+                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                        ),
+                    )
+            return deployment
+        except Exception as e:
+            if isinstance(e, litellm.RateLimitError):
+                raise e
+            return deployment  # don't fail calls if eg. redis fails to connect
+
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
@@ -91,7 +173,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
-            Update TPM/RPM usage on success
+            Update TPM usage on success
             """
             if kwargs["litellm_params"].get("metadata") is None:
                 pass
@@ -117,8 +199,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             )  # use the same timezone regardless of system clock
 
             tpm_key = f"{id}:tpm:{current_minute}"
-            rpm_key = f"{id}:rpm:{current_minute}"
-
             # ------------
             # Update usage
             # ------------
@@ -128,8 +208,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             await self.router_cache.async_increment_cache(
                 key=tpm_key, value=total_tokens
             )
-            ## RPM
-            await self.router_cache.async_increment_cache(key=rpm_key, value=1)
 
             ### TESTING ###
             if self.test_flag:
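`pre_call_rpm_check` above reads the local (in-memory) counter first and only touches the shared Redis-backed counter when the local view is still under the limit; the shared increment then acts as the authoritative reservation. A minimal sketch of that two-tier check, using plain dicts as hypothetical stand-ins for the local and shared cache tiers:

```python
import asyncio


async def check_and_reserve_rpm(key: str, limit: int, local: dict, shared: dict) -> int:
    # 1) cheap local read: fail fast without a shared-store round trip
    seen_locally = local.get(key)
    if seen_locally is not None and seen_locally >= limit:
        raise RuntimeError(f"Deployment over user-defined ratelimit. rpm limit={limit}. current usage={seen_locally}")
    # 2) authoritative reservation: increment the shared counter and re-check
    shared[key] = shared.get(key, 0) + 1
    local[key] = shared[key]  # keep the local view in sync for the next caller
    if shared[key] > limit:
        raise RuntimeError(f"Deployment over user-defined ratelimit. rpm limit={limit}. current usage={shared[key]}")
    return shared[key]


async def main():
    local: dict = {}
    shared: dict = {}
    for _ in range(5):
        try:
            used = await check_and_reserve_rpm("my-deployment:rpm:14-05", 3, local, shared)
            print("reserved slot", used)
        except RuntimeError as e:
            print("rejected:", e)


asyncio.run(main())
```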
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 8afd575f3..8509099c6 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum
 
 
 class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index 792e0c1f6..fa8c7fff7 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -67,12 +67,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
 
-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true
 
 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index 6fbdb7be5..28b7cde46 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -194,7 +194,7 @@ async def test_chat_completion():
         await chat_completion(session=session, key=key_2)
 
 
-@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
 @pytest.mark.asyncio
 async def test_chat_completion_ratelimit():
     """
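With `router_settings` enabled and `test_chat_completion_ratelimit` un-skipped above, a concurrency check against the running proxy might look like the sketch below. This is hypothetical and not the actual test body; the port and master key come from the config in this PR, while the model name is an assumption:

```python
import asyncio

import httpx


async def one_request(client: httpx.AsyncClient) -> int:
    resp = await client.post(
        "http://0.0.0.0:4000/chat/completions",
        headers={"Authorization": "Bearer sk-1234"},
        json={
            "model": "fake-openai-endpoint",  # assumed model name from a local test config
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    return resp.status_code


async def main():
    async with httpx.AsyncClient() as client:
        statuses = await asyncio.gather(*[one_request(client) for _ in range(10)])
    # with a deployment rpm limit below 10, some requests should come back as 429s
    print(statuses)


asyncio.run(main())
```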