# What this does:
## Gets a key's Redis cache and stores it in memory for 1 minute.
## This reduces the number of Redis GET requests the proxy makes during high traffic.
### [BETA] This is in beta and might change.
import json
import traceback
from typing import Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_BatchRedisRequests(CustomLogger):
    # Class variables or attributes
    in_memory_cache: Optional[InMemoryCache] = None

    def __init__(self):
        if litellm.cache is not None:
            litellm.cache.async_get_cache = (
                self.async_get_cache
            )  # map the litellm 'get_cache' function to our custom function
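            # NOTE: this monkey-patches the shared litellm.cache object, so every
            # cache read on this proxy flows through async_get_cache() below for
            # as long as this hook is installed.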

    def print_verbose(
        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
    ):
        if debug_level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        elif debug_level == "INFO":
            verbose_proxy_logger.info(print_statement)
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            """
            Get the user key.

            Check if a key starting with `litellm:<api_key>:<call_type>` exists in-memory.

            If not, fetch the relevant cache entries from Redis.
            """
            api_key = user_api_key_dict.api_key

            cache_key_name = f"litellm:{api_key}:{call_type}"
            self.in_memory_cache = cache.in_memory_cache

            key_value_dict = {}
            in_memory_cache_exists = False
            for key in cache.in_memory_cache.cache_dict.keys():
                if isinstance(key, str) and key.startswith(cache_key_name):
                    in_memory_cache_exists = True

            if in_memory_cache_exists is False and litellm.cache is not None:
                """
                - Check if `litellm.Cache` is Redis
                - Get the relevant values
                """
                if litellm.cache.type is not None and isinstance(
                    litellm.cache.cache, RedisCache
                ):
                    self.print_verbose(f"cache_key_name: {cache_key_name}")
                    # Use the SCAN iterator to fetch keys matching the pattern
                    keys = await litellm.cache.cache.async_scan_iter(
                        pattern=cache_key_name, count=100
                    )
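                    # NOTE: count=100 is a per-iteration batch-size hint to Redis
                    # SCAN, not a cap on the total number of keys returned.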
                    # If you need the truly "last" key based on time or another
                    # criterion, ensure your key naming or storage strategy allows
                    # this determination. Here you would sort or filter the keys
                    # as needed based on your strategy.
                    self.print_verbose(f"redis keys: {keys}")
                    if len(keys) > 0:
                        key_value_dict = (
                            await litellm.cache.cache.async_batch_get_cache(
                                key_list=keys
                            )
                        )

                    ## Add to cache: copy the Redis values into the in-memory
                    ## cache with a 60s TTL (the "1 minute" in the module header)
                    if len(key_value_dict) > 0:
                        await cache.in_memory_cache.async_set_cache_pipeline(
                            cache_list=list(key_value_dict.items()), ttl=60
                        )

            ## Set cache namespace if it's a miss
            data["metadata"]["redis_namespace"] = cache_key_name
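            # litellm.cache.get_cache_key picks this namespace up (see the comment
            # in async_get_cache below), so keys for this api_key/call_type are
            # written and read under the "litellm:<api_key>:<call_type>" prefix
            # that the SCAN above searches for.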
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_get_cache(self, *args, **kwargs):
        """
        - Check if the cache key is in-memory

        - Else:
            - fetch the missing cache key from Redis
            - update the in-memory cache
            - return the Redis cache result
        """
        try:  # never block execution
            cache_key: Optional[str] = None
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            elif litellm.cache is not None:
                cache_key = litellm.cache.get_cache_key(
                    *args, **kwargs
                )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic

            if (
                cache_key is not None
                and self.in_memory_cache is not None
                and litellm.cache is not None
            ):
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
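                # Cache-control semantics: a caller-supplied s-maxage / s-max-age
                # bounds how old a cached entry may be; _get_cache_logic below
                # applies this max_age. With no cache args, max_age is infinite.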
                # Try in-memory first; fall back to Redis on a miss and keep the
                # result in memory for 60s
                cached_result = self.in_memory_cache.get_cache(
                    cache_key, *args, **kwargs
                )
                if cached_result is None:
                    cached_result = await litellm.cache.cache.async_get_cache(
                        cache_key, *args, **kwargs
                    )
                    if cached_result is not None:
                        await self.in_memory_cache.async_set_cache(
                            cache_key, cached_result, ttl=60
                        )
                return litellm.cache._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            return None
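

# --- Usage sketch (illustrative) ---------------------------------------------
# A minimal sketch of how a proxy might register this hook, assuming the
# standard litellm callback mechanism; the actual wiring lives in the proxy
# server, not in this module.
#
#   import litellm
#   from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests
#
#   batch_redis_obj = _PROXY_BatchRedisRequests()  # __init__ patches litellm.cache.async_get_cache
#   litellm.callbacks.append(batch_redis_obj)      # the proxy then runs async_pre_call_hook per request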