Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat(batch_redis_get.py): batch redis GET requests for a given key + call type
Reduces the number of Redis GET requests the proxy makes in high-throughput scenarios.
This commit is contained in:
parent e033e84720
commit 226953e1d8
5 changed files with 189 additions and 5 deletions
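The pattern behind this change is independent of litellm: rather than issuing one Redis GET per incoming request, the proxy can scan the request's key namespace once, fetch every match in a single pipelined round trip, and serve repeat lookups from process memory for a short TTL. Below is a minimal sketch of that pattern using redis.asyncio; the litellm:<api_key>:<call_type> prefix and the one-minute TTL mirror this hook, while the client setup and key names are illustrative placeholders, not litellm code.

# Illustrative sketch of the batched-GET pattern (not litellm code).
# Assumes a reachable Redis at localhost:6379; key names are placeholders.
import asyncio
import time

import redis.asyncio as redis

_local_cache: dict = {}  # key -> (fetched_at, value)
_LOCAL_TTL = 60  # seconds; mirrors the hook's "store in memory for 1 minute"


async def batch_get_namespace(client: redis.Redis, namespace: str) -> dict:
    """Fetch every key under `namespace` with one SCAN pass and one pipelined GET."""
    keys = [key async for key in client.scan_iter(match=f"{namespace}*", count=100)]
    if not keys:
        return {}
    async with client.pipeline(transaction=False) as pipe:
        for key in keys:
            pipe.get(key)  # queue the GET; nothing is sent yet
        values = await pipe.execute()  # one round trip instead of len(keys) GETs
    now = time.time()
    fetched = dict(zip(keys, values))
    for key, value in fetched.items():
        _local_cache[key] = (now, value)  # serve repeat lookups from memory
    return fetched


def get_local(key: str):
    """Return a locally cached value if it is younger than the TTL, else None."""
    entry = _local_cache.get(key)
    if entry is None or time.time() - entry[0] > _LOCAL_TTL:
        return None
    return entry[1]


async def main():
    client = redis.Redis(host="localhost", port=6379, decode_responses=True)
    fetched = await batch_get_namespace(client, "litellm:sk-test:completion")
    print(f"prefetched {len(fetched)} keys in one round trip")


if __name__ == "__main__":
    asyncio.run(main())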
124  litellm/proxy/hooks/batch_redis_get.py  Normal file
@@ -0,0 +1,124 @@
# What this does:
## Gets a key's Redis cache and stores it in memory for 1 minute.
## This reduces the number of Redis GET requests the proxy makes during high traffic.
### [BETA] This feature is in beta and might change.

from typing import Optional, Literal
import litellm
from litellm.caching import DualCache, RedisCache, InMemoryCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException
import json, traceback


class _PROXY_BatchRedisRequests(CustomLogger):
    # Class variables or attributes
    in_memory_cache: Optional[InMemoryCache] = None

    def __init__(self):
        litellm.cache.async_get_cache = (
            self.async_get_cache
        )  # map litellm's 'async_get_cache' function to our custom function

    def print_verbose(
        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
    ):
        if debug_level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        elif debug_level == "INFO":
            verbose_proxy_logger.info(print_statement)
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            """
            Get the user key.

            Check if a key starting with `litellm:<api_key>:<call_type>:` exists in-memory.

            If not, get the relevant values from Redis.
            """
            api_key = user_api_key_dict.api_key

            cache_key_name = f"litellm:{api_key}:{call_type}"
            self.in_memory_cache = cache.in_memory_cache

            key_value_dict = {}
            in_memory_cache_exists = False
            for key in cache.in_memory_cache.cache_dict.keys():
                if isinstance(key, str) and key.startswith(cache_key_name):
                    in_memory_cache_exists = True

            if not in_memory_cache_exists and litellm.cache is not None:
                """
                - Check if `litellm.cache` is Redis
                - Get the relevant values
                """
                if litellm.cache.type is not None and isinstance(
                    litellm.cache.cache, RedisCache
                ):
                    self.print_verbose(f"cache_key_name: {cache_key_name}")
                    # Use the SCAN iterator to fetch keys matching the pattern
                    keys = await litellm.cache.cache.async_scan_iter(
                        pattern=cache_key_name, count=100
                    )
                    # If you need the truly "last" key based on time or another criterion,
                    # ensure your key naming or storage strategy allows this determination;
                    # sort or filter the keys here as needed.
                    self.print_verbose(f"redis keys: {keys}")
                    if len(keys) > 0:
                        # Fetch all matching values in a single pipelined round trip
                        key_value_dict = (
                            await litellm.cache.cache.async_get_cache_pipeline(
                                key_list=keys
                            )
                        )

                        ## Add to in-memory cache
                        for key, value in key_value_dict.items():
                            _cache_key = f"{cache_key_name}:{key}"
                            cache.in_memory_cache.cache_dict[_cache_key] = value

                    ## Set cache namespace if it's a miss
                    data["metadata"]["redis_namespace"] = cache_key_name
        except HTTPException as e:
            raise e
        except Exception:
            traceback.print_exc()

    async def async_get_cache(self, *args, **kwargs):
        """
        - Return the value if the cache key is in-memory
        - Else return None
        """
        try:  # never block execution
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = litellm.cache.get_cache_key(
                    *args, **kwargs
                )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
            if cache_key is not None and self.in_memory_cache is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = self.in_memory_cache.get_cache(
                    cache_key, *args, **kwargs
                )
                return litellm.cache._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            return None
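For orientation, the two halves work together like this: async_pre_call_hook prefetches the request's Redis namespace into the in-memory side of the DualCache, and async_get_cache (patched over litellm.cache.async_get_cache in __init__) then answers cache lookups from memory only. The sketch below is a hypothetical way to drive the hook directly, not the proxy's real wiring; it assumes a reachable Redis and a litellm version containing this commit, and the host, port, and sk-test key are placeholders.

# Hypothetical driver for the hook above; placeholder credentials and keys.
# Assumes litellm is configured with a Redis cache before the hook is built.
import asyncio

import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests


async def main():
    # litellm.cache must exist first: the hook's __init__ patches its async_get_cache
    litellm.cache = litellm.Cache(type="redis", host="localhost", port=6379)
    hook = _PROXY_BatchRedisRequests()

    dual_cache = DualCache()  # in-memory + redis layers, as the proxy would pass in
    user = UserAPIKeyAuth(api_key="sk-test")
    data = {"metadata": {}}

    # Prefetch every `litellm:sk-test:completion*` key from Redis into memory
    await hook.async_pre_call_hook(
        user_api_key_dict=user,
        cache=dual_cache,
        data=data,
        call_type="completion",
    )
    print("request metadata after hook:", data["metadata"])

    # Subsequent cache reads go through the patched, in-memory-only lookup
    hit = await litellm.cache.async_get_cache(
        cache_key="litellm:sk-test:completion:example-hash"
    )
    print("in-memory cache hit:", hit)


if __name__ == "__main__":
    asyncio.run(main())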