feat(lowest_tpm_rpm_v2.py): move to using redis.incr and redis.mget for getting model usage from redis

makes routing work across multiple instances
This commit is contained in:
Krrish Dholakia 2024-04-10 14:56:23 -07:00
parent 06a0ca1e80
commit 31e2d4e6d1
5 changed files with 437 additions and 12 deletions

View file

@ -81,9 +81,29 @@ class InMemoryCache(BaseCache):
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs):
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
@ -246,6 +266,19 @@ class RedisCache(BaseCache):
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer()
async def async_increment(self, key, value: int, **kwargs):
_redis_client = self.init_async_client()
try:
async with _redis_client as redis_client:
await redis_client.incr(name=key, amount=value)
except Exception as e:
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
str(e),
value,
)
traceback.print_exc()
async def flush_cache_buffer(self):
print_verbose(
f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
@ -283,6 +316,32 @@ class RedisCache(BaseCache):
traceback.print_exc()
logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
def batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
key_value_dict = {}
try:
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = self.redis_client.mget(keys=_keys)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results))
decoded_results = {
k.decode("utf-8"): self._get_cache_logic(v)
for k, v in key_value_dict.items()
}
return decoded_results
except Exception as e:
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client()
key = self.check_and_fix_namespace(key=key)
@ -301,7 +360,7 @@ class RedisCache(BaseCache):
f"LiteLLM Caching: async get() - Got exception from REDIS: {str(e)}"
)
async def async_get_cache_pipeline(self, key_list) -> dict:
async def async_batch_get_cache(self, key_list) -> dict:
"""
Use Redis for bulk read operations
"""
@ -309,14 +368,11 @@ class RedisCache(BaseCache):
key_value_dict = {}
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
# Queue the get operations in the pipeline for all keys.
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
pipe.get(cache_key) # Queue GET command in pipeline
# Execute the pipeline and await the results.
results = await pipe.execute()
_keys = []
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = await redis_client.mget(keys=_keys)
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
@ -897,6 +953,39 @@ class DualCache(BaseCache):
except Exception as e:
traceback.print_exc()
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.batch_get_cache(keys, **kwargs)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
self.in_memory_cache.set_cache(key, redis_result[key], **kwargs)
for key, value in redis_result.items():
result[sublist_keys.index(key)] = value
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
# Try to fetch from in-memory cache first
try:
@ -930,6 +1019,50 @@ class DualCache(BaseCache):
except Exception as e:
traceback.print_exc()
async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
"""
- for the none values in the result
- check the redis cache
"""
sublist_keys = [
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, **kwargs
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
sublist_dict = dict(zip(sublist_keys, redis_result))
for key, value in sublist_dict.items():
result[sublist_keys.index(key)] = value[key]
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
traceback.print_exc()
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
try:
if self.in_memory_cache is not None:
@ -941,6 +1074,24 @@ class DualCache(BaseCache):
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
async def async_increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
):
"""
Key - the key in cache
Value - int - the value you want to increment by
"""
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_increment(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
await self.redis_cache.async_increment(key, value, **kwargs)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
def flush_cache(self):
if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache()