feat(lowest_tpm_rpm_v2.py): move to using redis.incr and redis.mget for getting model usage from redis

makes routing work across multiple instances
This commit is contained in:
Krrish Dholakia 2024-04-10 14:56:23 -07:00
parent b2741933dc
commit 180cf9bd5c
5 changed files with 437 additions and 12 deletions

View file

@ -81,9 +81,29 @@ class InMemoryCache(BaseCache):
return cached_response return cached_response
return None return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs) return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs):
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
def flush_cache(self): def flush_cache(self):
self.cache_dict.clear() self.cache_dict.clear()
self.ttl_dict.clear() self.ttl_dict.clear()
@ -246,6 +266,19 @@ class RedisCache(BaseCache):
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer() await self.flush_cache_buffer()
async def async_increment(self, key, value: int, **kwargs):
_redis_client = self.init_async_client()
try:
async with _redis_client as redis_client:
await redis_client.incr(name=key, amount=value)
except Exception as e:
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
async def flush_cache_buffer(self): async def flush_cache_buffer(self):
print_verbose( print_verbose(
f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}" f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
@ -283,6 +316,32 @@ class RedisCache(BaseCache):
traceback.print_exc() traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e) logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
def batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
key_value_dict = {}
try:
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = self.redis_client.mget(keys=_keys)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
decoded_results = {
k.decode("utf-8"): self._get_cache_logic(v)
for k, v in key_value_dict.items()
}
return decoded_results
except Exception as e:
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key) key = self.check_and_fix_namespace(key=key)
@ -301,7 +360,7 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}" f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
) )
async def async_get_cache_pipeline(self, key_list) -> dict: async def async_batch_get_cache(self, key_list) -> dict:
""" """
Use Redis for bulk read operations Use Redis for bulk read operations
""" """
@ -309,14 +368,11 @@ class RedisCache(BaseCache):
key_value_dict = {} key_value_dict = {}
try: try:
async with _redis_client as redis_client: async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe: _keys = []
# Queue the get operations in the pipeline for all keys.
for cache_key in key_list: for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key) cache_key = self.check_and_fix_namespace(key=cache_key)
pipe.get(cache_key) # Queue GET command in pipeline _keys.append(cache_key)
results = await redis_client.mget(keys=_keys)
# Execute the pipeline and await the results.
results = await pipe.execute()
# Associate the results back with their keys. # Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'. # 'results' is a list of values corresponding to the order of keys in 'key_list'.
@ -897,6 +953,39 @@ class DualCache(BaseCache):
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
for key, value in redis_result.items():
result[sublist_keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_get_cache(self, key, local_only: bool = False, **kwargs): async def async_get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first # Try to fetch from in-memory cache first
try: try:
@ -930,6 +1019,50 @@ class DualCache(BaseCache):
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, **kwargs
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
sublist_dict = dict(zip(sublist_keys, redis_result))
for key, value in sublist_dict.items():
result[sublist_keys.index(key)] = value[key]
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
try: try:
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
@ -941,6 +1074,24 @@ class DualCache(BaseCache):
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc() traceback.print_exc()
async def async_increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
):
"""
Key - the key in cache
Value - int - the value you want to increment by
"""
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_increment(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
await self.redis_cache.async_increment(key, value, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
def flush_cache(self): def flush_cache(self):
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache() self.in_memory_cache.flush_cache()

View file

@ -28,7 +28,7 @@ litellm_settings:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
router_settings: router_settings:
routing_strategy: usage-based-routing routing_strategy: usage-based-routing-v2
redis_host: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com redis_host: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
redis_password: madeBerri@992 redis_password: madeBerri@992
redis_port: 16337 redis_port: 16337

View file

@ -79,7 +79,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
self.print_verbose(f"redis keys: {keys}") self.print_verbose(f"redis keys: {keys}")
if len(keys) > 0: if len(keys) > 0:
key_value_dict = ( key_value_dict = (
await litellm.cache.cache.async_get_cache_pipeline( await litellm.cache.cache.async_batch_get_cache(
key_list=keys key_list=keys
) )
) )

View file

@ -21,6 +21,7 @@ from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import ( from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport, CustomHTTPTransport,
AsyncCustomHTTPTransport, AsyncCustomHTTPTransport,
@ -273,6 +274,12 @@ class Router:
) )
if isinstance(litellm.callbacks, list): if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
elif routing_strategy == "usage-based-routing-v2":
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
router_cache=self.cache, model_list=self.model_list
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
elif routing_strategy == "latency-based-routing": elif routing_strategy == "latency-based-routing":
self.lowestlatency_logger = LowestLatencyLoggingHandler( self.lowestlatency_logger = LowestLatencyLoggingHandler(
router_cache=self.cache, router_cache=self.cache,
@ -2506,7 +2513,16 @@ class Router:
messages=messages, messages=messages,
input=input, input=input,
) )
elif (
self.routing_strategy == "usage-based-routing-v2"
and self.lowesttpm_logger_v2 is not None
):
deployment = self.lowesttpm_logger_v2.get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments,
messages=messages,
input=input,
)
if deployment is None: if deployment is None:
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"

View file

@ -0,0 +1,258 @@
#### What this does ####
# identifies lowest tpm deployment
import dotenv, os, requests, random
from typing import Optional, Union, List, Dict
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback, asyncio
from litellm import token_counter
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger
from litellm.utils import print_verbose
class LowestTPMLoggingHandler_v2(CustomLogger):
"""
Updated version of TPM/RPM Logging.
Meant to work across instances.
Caches individual models, not model_groups
Uses batch get (redis.mget)
Increments tpm/rpm limit using redis.incr
"""
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(self, router_cache: DualCache, model_list: list):
self.router_cache = router_cache
self.model_list = model_list
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
current_minute = datetime.now().strftime("%H-%M")
tpm_key = f"{model_group}:tpm:{current_minute}"
rpm_key = f"{model_group}:rpm:{current_minute}"
# ------------
# Update usage
# ------------
## TPM
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(key=tpm_key, value=request_count_dict)
## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(key=rpm_key, value=request_count_dict)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
current_minute = datetime.now().strftime("%H-%M")
tpm_key = f"{id}:tpm:{current_minute}"
rpm_key = f"{id}:rpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
await self.router_cache.async_increment_cache(
key=tpm_key, value=total_tokens
)
## RPM
await self.router_cache.async_increment_cache(key=rpm_key, value=1)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
traceback.print_exc()
pass
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Async implementation of get deployments.
Reduces time to retrieve the tpm/rpm values from cache
"""
pass
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Returns a deployment with the lowest TPM/RPM usage.
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
current_minute = datetime.now().strftime("%H-%M")
tpm_keys = []
rpm_keys = []
for m in healthy_deployments:
if isinstance(m, dict):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
tpm_key = "{}:tpm:{}".format(id, current_minute)
rpm_key = "{}:rpm:{}".format(id, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
tpm_values = self.router_cache.batch_get_cache(
keys=tpm_keys
) # [1, 2, None, ..]
rpm_values = self.router_cache.batch_get_cache(
keys=rpm_keys
) # [1, 2, None, ..]
tpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(tpm_keys):
tpm_dict[tpm_keys[idx]] = tpm_values[idx]
rpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(rpm_keys):
rpm_dict[rpm_keys[idx]] = rpm_values[idx]
try:
input_tokens = token_counter(messages=messages, text=input)
except:
input_tokens = 0
verbose_router_logger.debug(f"input_tokens={input_tokens}")
# -----------------------
# Find lowest used model
# ----------------------
lowest_tpm = float("inf")
if tpm_dict is None: # base case - none of the deployments have been used
# initialize a tpm dict with {model_id: 0}
tpm_dict = {}
for deployment in healthy_deployments:
tpm_dict[deployment["model_info"]["id"]] = 0
else:
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in tpm_dict:
tpm_dict[d["model_info"]["id"]] = 0
all_deployments = tpm_dict
deployment = None
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = float("inf")
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = float("inf")
if item_tpm + input_tokens > _deployment_tpm:
continue
elif (rpm_dict is not None and item in rpm_dict) and (
rpm_dict[item] + 1 > _deployment_rpm
):
continue
elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm
deployment = _deployment
print_verbose("returning picked lowest tpm/rpm deployment.")
return deployment