diff --git a/litellm/caching.py b/litellm/caching.py index 6bf645f77b..f569a05087 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -81,9 +81,29 @@ class InMemoryCache(BaseCache): return cached_response return None + def batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + async def async_get_cache(self, key, **kwargs): return self.get_cache(key=key, **kwargs) + async def async_batch_get_cache(self, keys: list, **kwargs): + return_val = [] + for k in keys: + val = self.get_cache(key=k, **kwargs) + return_val.append(val) + return return_val + + async def async_increment(self, key, value: int, **kwargs): + # get the value + init_value = await self.async_get_cache(key=key) or 0 + value = init_value + value + await self.async_set_cache(key, value, **kwargs) + def flush_cache(self): self.cache_dict.clear() self.ttl_dict.clear() @@ -246,6 +266,19 @@ class RedisCache(BaseCache): if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: await self.flush_cache_buffer() + async def async_increment(self, key, value: int, **kwargs): + _redis_client = self.init_async_client() + try: + async with _redis_client as redis_client: + await redis_client.incr(name=key, amount=value) + except Exception as e: + verbose_logger.error( + "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + traceback.print_exc() + async def flush_cache_buffer(self): print_verbose( f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}" @@ -283,6 +316,32 @@ class RedisCache(BaseCache): traceback.print_exc() logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) + def batch_get_cache(self, key_list) -> dict: + """ + Use Redis for bulk read operations + """ + key_value_dict = {} + try: + _keys = [] + for cache_key in key_list: + cache_key = self.check_and_fix_namespace(key=cache_key) + _keys.append(cache_key) + results = self.redis_client.mget(keys=_keys) + + # Associate the results back with their keys. + # 'results' is a list of values corresponding to the order of keys in 'key_list'. + key_value_dict = dict(zip(key_list, results)) + + decoded_results = { + k.decode("utf-8"): self._get_cache_logic(v) + for k, v in key_value_dict.items() + } + + return decoded_results + except Exception as e: + print_verbose(f"Error occurred in pipeline read - {str(e)}") + return key_value_dict + async def async_get_cache(self, key, **kwargs): _redis_client = self.init_async_client() key = self.check_and_fix_namespace(key=key) @@ -301,7 +360,7 @@ class RedisCache(BaseCache): f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}" ) - async def async_get_cache_pipeline(self, key_list) -> dict: + async def async_batch_get_cache(self, key_list) -> dict: """ Use Redis for bulk read operations """ @@ -309,14 +368,11 @@ class RedisCache(BaseCache): key_value_dict = {} try: async with _redis_client as redis_client: - async with redis_client.pipeline(transaction=True) as pipe: - # Queue the get operations in the pipeline for all keys. - for cache_key in key_list: - cache_key = self.check_and_fix_namespace(key=cache_key) - pipe.get(cache_key) # Queue GET command in pipeline - - # Execute the pipeline and await the results. - results = await pipe.execute() + _keys = [] + for cache_key in key_list: + cache_key = self.check_and_fix_namespace(key=cache_key) + _keys.append(cache_key) + results = await redis_client.mget(keys=_keys) # Associate the results back with their keys. # 'results' is a list of values corresponding to the order of keys in 'key_list'. @@ -897,6 +953,39 @@ class DualCache(BaseCache): except Exception as e: traceback.print_exc() + def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs): + try: + result = [None for _ in range(len(keys))] + if self.in_memory_cache is not None: + in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs) + + print_verbose(f"in_memory_result: {in_memory_result}") + if in_memory_result is not None: + result = in_memory_result + + if None in result and self.redis_cache is not None and local_only == False: + """ + - for the none values in the result + - check the redis cache + """ + sublist_keys = [ + key for key, value in zip(keys, result) if value is None + ] + # If not found in in-memory cache, try fetching from Redis + redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs) + if redis_result is not None: + # Update in-memory cache with the value from Redis + for key in redis_result: + self.in_memory_cache.set_cache(key, redis_result[key], **kwargs) + + for key, value in redis_result.items(): + result[sublist_keys.index(key)] = value + + print_verbose(f"async batch get cache: cache result: {result}") + return result + except Exception as e: + traceback.print_exc() + async def async_get_cache(self, key, local_only: bool = False, **kwargs): # Try to fetch from in-memory cache first try: @@ -930,6 +1019,50 @@ class DualCache(BaseCache): except Exception as e: traceback.print_exc() + async def async_batch_get_cache( + self, keys: list, local_only: bool = False, **kwargs + ): + try: + result = [None for _ in range(len(keys))] + if self.in_memory_cache is not None: + in_memory_result = await self.in_memory_cache.async_batch_get_cache( + keys, **kwargs + ) + + print_verbose(f"in_memory_result: {in_memory_result}") + if in_memory_result is not None: + result = in_memory_result + + if None in result and self.redis_cache is not None and local_only == False: + """ + - for the none values in the result + - check the redis cache + """ + sublist_keys = [ + key for key, value in zip(keys, result) if value is None + ] + # If not found in in-memory cache, try fetching from Redis + redis_result = await self.redis_cache.async_batch_get_cache( + sublist_keys, **kwargs + ) + + if redis_result is not None: + # Update in-memory cache with the value from Redis + for key in redis_result: + await self.in_memory_cache.async_set_cache( + key, redis_result[key], **kwargs + ) + + sublist_dict = dict(zip(sublist_keys, redis_result)) + + for key, value in sublist_dict.items(): + result[sublist_keys.index(key)] = value[key] + + print_verbose(f"async batch get cache: cache result: {result}") + return result + except Exception as e: + traceback.print_exc() + async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): try: if self.in_memory_cache is not None: @@ -941,6 +1074,24 @@ class DualCache(BaseCache): print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") traceback.print_exc() + async def async_increment_cache( + self, key, value: int, local_only: bool = False, **kwargs + ): + """ + Key - the key in cache + + Value - int - the value you want to increment by + """ + try: + if self.in_memory_cache is not None: + await self.in_memory_cache.async_increment(key, value, **kwargs) + + if self.redis_cache is not None and local_only == False: + await self.redis_cache.async_increment(key, value, **kwargs) + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") + traceback.print_exc() + def flush_cache(self): if self.in_memory_cache is not None: self.in_memory_cache.flush_cache() diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ce6d543727..f2298ac7bb 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -28,7 +28,7 @@ litellm_settings: max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET router_settings: - routing_strategy: usage-based-routing + routing_strategy: usage-based-routing-v2 redis_host: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com redis_password: madeBerri@992 redis_port: 16337 diff --git a/litellm/proxy/hooks/batch_redis_get.py b/litellm/proxy/hooks/batch_redis_get.py index 71588c9d40..64541c1bff 100644 --- a/litellm/proxy/hooks/batch_redis_get.py +++ b/litellm/proxy/hooks/batch_redis_get.py @@ -79,7 +79,7 @@ class _PROXY_BatchRedisRequests(CustomLogger): self.print_verbose(f"redis keys: {keys}") if len(keys) > 0: key_value_dict = ( - await litellm.cache.cache.async_get_cache_pipeline( + await litellm.cache.cache.async_batch_get_cache( key_list=keys ) ) diff --git a/litellm/router.py b/litellm/router.py index 78f7faeec1..c6ac52bc8c 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -21,6 +21,7 @@ from collections import defaultdict from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler +from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 from litellm.llms.custom_httpx.azure_dall_e_2 import ( CustomHTTPTransport, AsyncCustomHTTPTransport, @@ -273,6 +274,12 @@ class Router: ) if isinstance(litellm.callbacks, list): litellm.callbacks.append(self.lowesttpm_logger) # type: ignore + elif routing_strategy == "usage-based-routing-v2": + self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2( + router_cache=self.cache, model_list=self.model_list + ) + if isinstance(litellm.callbacks, list): + litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore elif routing_strategy == "latency-based-routing": self.lowestlatency_logger = LowestLatencyLoggingHandler( router_cache=self.cache, @@ -2506,7 +2513,16 @@ class Router: messages=messages, input=input, ) - + elif ( + self.routing_strategy == "usage-based-routing-v2" + and self.lowesttpm_logger_v2 is not None + ): + deployment = self.lowesttpm_logger_v2.get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, + messages=messages, + input=input, + ) if deployment is None: verbose_router_logger.info( f"get_available_deployment for model: {model}, No deployment available" diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py new file mode 100644 index 0000000000..991fd57c14 --- /dev/null +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -0,0 +1,258 @@ +#### What this does #### +# identifies lowest tpm deployment + +import dotenv, os, requests, random +from typing import Optional, Union, List, Dict +from datetime import datetime + +dotenv.load_dotenv() # Loading env variables using dotenv +import traceback, asyncio +from litellm import token_counter +from litellm.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm._logging import verbose_router_logger +from litellm.utils import print_verbose + + +class LowestTPMLoggingHandler_v2(CustomLogger): + """ + Updated version of TPM/RPM Logging. + + Meant to work across instances. + + Caches individual models, not model_groups + + Uses batch get (redis.mget) + + Increments tpm/rpm limit using redis.incr + """ + + test_flag: bool = False + logged_success: int = 0 + logged_failure: int = 0 + default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour + + def __init__(self, router_cache: DualCache, model_list: list): + self.router_cache = router_cache + self.model_list = model_list + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update TPM/RPM usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + total_tokens = response_obj["usage"]["total_tokens"] + + # ------------ + # Setup values + # ------------ + current_minute = datetime.now().strftime("%H-%M") + tpm_key = f"{model_group}:tpm:{current_minute}" + rpm_key = f"{model_group}:rpm:{current_minute}" + + # ------------ + # Update usage + # ------------ + + ## TPM + request_count_dict = self.router_cache.get_cache(key=tpm_key) or {} + request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens + + self.router_cache.set_cache(key=tpm_key, value=request_count_dict) + + ## RPM + request_count_dict = self.router_cache.get_cache(key=rpm_key) or {} + request_count_dict[id] = request_count_dict.get(id, 0) + 1 + + self.router_cache.set_cache(key=rpm_key, value=request_count_dict) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + traceback.print_exc() + pass + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + """ + Update TPM/RPM usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + total_tokens = response_obj["usage"]["total_tokens"] + + # ------------ + # Setup values + # ------------ + current_minute = datetime.now().strftime("%H-%M") + + tpm_key = f"{id}:tpm:{current_minute}" + rpm_key = f"{id}:rpm:{current_minute}" + + # ------------ + # Update usage + # ------------ + # update cache + + ## TPM + await self.router_cache.async_increment_cache( + key=tpm_key, value=total_tokens + ) + ## RPM + await self.router_cache.async_increment_cache(key=rpm_key, value=1) + + ### TESTING ### + if self.test_flag: + self.logged_success += 1 + except Exception as e: + traceback.print_exc() + pass + + async def async_get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + ): + """ + Async implementation of get deployments. + + Reduces time to retrieve the tpm/rpm values from cache + """ + pass + + def get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + ): + """ + Returns a deployment with the lowest TPM/RPM usage. + """ + # get list of potential deployments + verbose_router_logger.debug( + f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}" + ) + + current_minute = datetime.now().strftime("%H-%M") + tpm_keys = [] + rpm_keys = [] + for m in healthy_deployments: + if isinstance(m, dict): + id = m.get("model_info", {}).get( + "id" + ) # a deployment should always have an 'id'. this is set in router.py + tpm_key = "{}:tpm:{}".format(id, current_minute) + rpm_key = "{}:rpm:{}".format(id, current_minute) + + tpm_keys.append(tpm_key) + rpm_keys.append(rpm_key) + + tpm_values = self.router_cache.batch_get_cache( + keys=tpm_keys + ) # [1, 2, None, ..] + rpm_values = self.router_cache.batch_get_cache( + keys=rpm_keys + ) # [1, 2, None, ..] + + tpm_dict = {} # {model_id: 1, ..} + for idx, key in enumerate(tpm_keys): + tpm_dict[tpm_keys[idx]] = tpm_values[idx] + + rpm_dict = {} # {model_id: 1, ..} + for idx, key in enumerate(rpm_keys): + rpm_dict[rpm_keys[idx]] = rpm_values[idx] + + try: + input_tokens = token_counter(messages=messages, text=input) + except: + input_tokens = 0 + verbose_router_logger.debug(f"input_tokens={input_tokens}") + # ----------------------- + # Find lowest used model + # ---------------------- + lowest_tpm = float("inf") + + if tpm_dict is None: # base case - none of the deployments have been used + # initialize a tpm dict with {model_id: 0} + tpm_dict = {} + for deployment in healthy_deployments: + tpm_dict[deployment["model_info"]["id"]] = 0 + else: + for d in healthy_deployments: + ## if healthy deployment not yet used + if d["model_info"]["id"] not in tpm_dict: + tpm_dict[d["model_info"]["id"]] = 0 + + all_deployments = tpm_dict + + deployment = None + for item, item_tpm in all_deployments.items(): + ## get the item from model list + _deployment = None + for m in healthy_deployments: + if item == m["model_info"]["id"]: + _deployment = m + + if _deployment is None: + continue # skip to next one + + _deployment_tpm = None + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("tpm") + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm") + if _deployment_tpm is None: + _deployment_tpm = _deployment.get("model_info", {}).get("tpm") + if _deployment_tpm is None: + _deployment_tpm = float("inf") + + _deployment_rpm = None + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("rpm") + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm") + if _deployment_rpm is None: + _deployment_rpm = _deployment.get("model_info", {}).get("rpm") + if _deployment_rpm is None: + _deployment_rpm = float("inf") + + if item_tpm + input_tokens > _deployment_tpm: + continue + elif (rpm_dict is not None and item in rpm_dict) and ( + rpm_dict[item] + 1 > _deployment_rpm + ): + continue + elif item_tpm < lowest_tpm: + lowest_tpm = item_tpm + deployment = _deployment + print_verbose("returning picked lowest tpm/rpm deployment.") + return deployment