fix(batch_redis_get.py): handle custom namespace

Fix https://github.com/BerriAI/litellm/issues/5917
Krrish Dholakia 2024-09-28 13:00:45 -07:00
parent e9e086a0b6
commit efc06d4a03
5 changed files with 105 additions and 37 deletions

litellm/caching.py

@@ -2401,13 +2401,13 @@ class Cache:
         # Hexadecimal representation of the hash
         hash_hex = hash_object.hexdigest()
         print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
-        if self.namespace is not None:
-            hash_hex = f"{self.namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
-        elif kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
+        if kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
             _namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
             hash_hex = f"{_namespace}:{hash_hex}"
             print_verbose(f"Hashed Key with Namespace: {hash_hex}")
+        elif self.namespace is not None:
+            hash_hex = f"{self.namespace}:{hash_hex}"
+            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
         return hash_hex

     def generate_streaming_content(self, content):
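Before this change a `Cache`-level namespace always won; now a per-request `redis_namespace` passed via `metadata` takes priority, and the cache-level namespace is only the fallback. A minimal usage sketch (namespace values are illustrative; a reachable Redis is assumed):

import litellm
from litellm.caching import Cache

# Cache-level namespace: the fallback after this change.
litellm.cache = Cache(type="redis", namespace="global_ns")

# A per-request namespace in metadata now takes priority, so this call's
# key is stored as "team_a:<sha256>" rather than "global_ns:<sha256>".
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    metadata={"redis_namespace": "team_a"},
    mock_response="hello",  # keeps the example offline
)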

litellm/litellm_core_utils/litellm_logging.py

@@ -2141,7 +2141,7 @@ def _init_custom_logger_compatible_class(
     llm_router: Optional[
         Any
     ],  # expect litellm.Router, but typing errors due to circular import
-    premium_user: bool = False,
+    premium_user: Optional[bool] = None,
 ) -> Optional[CustomLogger]:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:

@@ -2184,7 +2184,7 @@ def _init_custom_logger_compatible_class(
             _prometheus_logger = PrometheusLogger()
             _in_memory_loggers.append(_prometheus_logger)
             return _prometheus_logger  # type: ignore
-        else:
+        elif premium_user is False:
             verbose_logger.warning(
                 f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
             )
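The signature change makes `premium_user` tri-state: `True` (premium), `False` (explicitly not premium), and `None` (caller didn't say). The enterprise warning now fires only on an explicit `False` instead of on every non-premium code path. A stripped-down sketch of the pattern (names are illustrative, not litellm's):

from typing import Optional

def init_prometheus_logger(premium_user: Optional[bool] = None) -> None:
    if premium_user:
        print("initializing Prometheus logger")
    elif premium_user is False:
        # Warn only when the caller explicitly said "not premium";
        # an unknown (None) value no longer triggers the warning.
        print("warning: Prometheus metrics are enterprise-only")

init_prometheus_logger()       # None -> silent
init_prometheus_logger(False)  # explicit False -> warning
init_prometheus_logger(True)   # True -> initializes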

litellm/proxy/proxy_config.yaml

@@ -52,19 +52,10 @@ model_list:
     model: "vertex_ai/gemini-flash-experimental"

 litellm_settings:
-  callbacks: ["prometheus"]
-  redact_user_api_key_info: true
-
-default_team_settings:
-  - team_id: "09ae376d-f6c8-42cd-88be-59717135684d" # team 1
-    success_callbacks: ["langfuse"]
-    langfuse_public_key: "pk-lf-1"
-    langfuse_secret: "sk-lf-1"
-    langfuse_host: ""
-  - team_id: "e5db79db-d623-4a5b-afd5-162be56074df" # team2
-    success_callback: ["langfuse"]
-    langfuse_public_key: "pk-lf-2"
-    langfuse_secret: "sk-lf-2"
-    langfuse_host: ""
+  json_logs: true
+  cache: true
+  cache_params:
+    type: "redis"
+    # namespace: "litellm_caching"
+    ttl: 900
+  callbacks: ["batch_redis_requests"]

litellm/proxy/hooks/batch_redis_get.py

@@ -3,14 +3,17 @@
 ## This reduces the number of REDIS GET requests made during high-traffic by the proxy.
 ### [BETA] this is in Beta. And might change.

-from typing import Optional, Literal
-import litellm
-from litellm.caching import DualCache, RedisCache, InMemoryCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from litellm._logging import verbose_proxy_logger
+import json
+import traceback
+from typing import Literal, Optional
+
 from fastapi import HTTPException
-import json, traceback
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache, InMemoryCache, RedisCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth


 class _PROXY_BatchRedisRequests(CustomLogger):
@@ -18,9 +21,10 @@ class _PROXY_BatchRedisRequests(CustomLogger):
     in_memory_cache: Optional[InMemoryCache] = None

     def __init__(self):
-        litellm.cache.async_get_cache = (
-            self.async_get_cache
-        )  # map the litellm 'get_cache' function to our custom function
+        if litellm.cache is not None:
+            litellm.cache.async_get_cache = (
+                self.async_get_cache
+            )  # map the litellm 'get_cache' function to our custom function

     def print_verbose(
         self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
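The new guard lets the hook be constructed before any global cache exists (previously `litellm.cache = None` would raise an `AttributeError` here). The mechanism itself is instance-level method patching; a stripped-down sketch with hypothetical classes:

class _Cache:
    async def async_get_cache(self, key):
        return None  # default lookup: always a miss

class _BatchHook:
    def __init__(self, cache):
        if cache is not None:  # guard: no cache configured yet -> no-op
            # Rebind the *instance's* lookup to this hook's method;
            # the _Cache class itself is left untouched.
            cache.async_get_cache = self.async_get_cache

    async def async_get_cache(self, key):
        return f"intercepted: {key}"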
@@ -58,7 +62,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
                 if isinstance(key, str) and key.startswith(cache_key_name):
                     in_memory_cache_exists = True

-            if in_memory_cache_exists == False and litellm.cache is not None:
+            if in_memory_cache_exists is False and litellm.cache is not None:
                 """
                 - Check if `litellm.Cache` is redis
                 - Get the relevant values
@@ -105,16 +109,25 @@ class _PROXY_BatchRedisRequests(CustomLogger):
         """
         - Check if the cache key is in-memory
-        - Else return None
+        - Else:
+            - add missing cache key from REDIS
+            - update in-memory cache
+            - return redis cache request
         """
         try:  # never block execution
+            cache_key: Optional[str] = None
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
-            else:
+            elif litellm.cache is not None:
                 cache_key = litellm.cache.get_cache_key(
                     *args, **kwargs
                 )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic

-            if cache_key is not None and self.in_memory_cache is not None:
+            if (
+                cache_key is not None
+                and self.in_memory_cache is not None
+                and litellm.cache is not None
+            ):
                 cache_control_args = kwargs.get("cache", {})
                 max_age = cache_control_args.get(
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
@@ -122,8 +135,16 @@ class _PROXY_BatchRedisRequests(CustomLogger):
                 cached_result = self.in_memory_cache.get_cache(
                     cache_key, *args, **kwargs
                 )
+                if cached_result is None:
+                    cached_result = await litellm.cache.cache.async_get_cache(
+                        cache_key, *args, **kwargs
+                    )
+                    if cached_result is not None:
+                        await self.in_memory_cache.async_set_cache(
+                            cache_key, cached_result, ttl=60
+                        )
                 return litellm.cache._get_cache_logic(
                     cached_result=cached_result, max_age=max_age
                 )
-        except Exception as e:
+        except Exception:
             return None
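This is the core of the fix: on an in-memory miss the hook now falls through to Redis and backfills the in-memory layer with a short 60-second TTL, instead of returning nothing. The pattern in isolation, as a minimal runnable sketch with a plain dict standing in for both layers (names are illustrative, not litellm APIs):

import asyncio
import time

class ReadThroughCache:
    """Two-tier read-through cache: a fast local dict in front of a slow store."""

    def __init__(self, slow_store):
        self.local = {}                # key -> (value, expires_at)
        self.slow_store = slow_store   # stands in for Redis

    async def get(self, key):
        hit = self.local.get(key)
        if hit is not None and hit[1] > time.monotonic():
            return hit[0]                     # in-memory hit
        value = self.slow_store.get(key)      # fall through to "Redis"
        if value is not None:
            # Backfill with a short TTL so hot keys stop hitting Redis.
            self.local[key] = (value, time.monotonic() + 60)
        return value

async def main():
    cache = ReadThroughCache({"k": "v"})
    print(await cache.get("k"))  # first read comes from the slow store
    print(await cache.get("k"))  # second read is served in-memory

asyncio.run(main())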

litellm/tests/test_caching.py

@@ -2100,3 +2100,59 @@ async def test_redis_sentinel_caching():

     print(f"stored_val: {stored_val}")
     assert stored_val_2["id"] == response1.id
+
+
+@pytest.mark.asyncio
+async def test_redis_proxy_batch_redis_get_cache():
+    """
+    Tests batch_redis_get.py
+
+    - make 1st call -> expect miss
+    - make 2nd call -> expect hit
+    """
+    from litellm.caching import Cache, DualCache
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests
+
+    litellm.cache = Cache(
+        type="redis",
+        host=os.getenv("REDIS_HOST"),
+        port=os.getenv("REDIS_PORT"),
+        password=os.getenv("REDIS_PASSWORD"),
+        namespace="test_namespace",
+    )
+
+    batch_redis_get_obj = (
+        _PROXY_BatchRedisRequests()
+    )  # overrides the .async_get_cache method
+
+    user_api_key_cache = DualCache()
+
+    import uuid
+
+    batch_redis_get_obj.in_memory_cache = user_api_key_cache.in_memory_cache
+
+    messages = [{"role": "user", "content": "hi {}".format(uuid.uuid4())}]
+    # 1st call -> expect miss
+    response = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        mock_response="hello",
+    )
+    assert response is not None
+    assert "cache_key" not in response._hidden_params
+    print(response._hidden_params)
+
+    await asyncio.sleep(1)
+
+    # 2nd call -> expect hit
+    response = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        mock_response="hello",
+    )
+    print(response._hidden_params)
+    assert "cache_key" in response._hidden_params