fix(batch_redis_get.py): handle custom namespace

Fix https://github.com/BerriAI/litellm/issues/5917
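
Behavior note: a request-level `metadata["redis_namespace"]` now takes priority over the `Cache`-level namespace when the cache key is built (previously `self.namespace` won). A minimal sketch of the resulting key, not part of this commit - namespace values are illustrative and the usual REDIS_HOST/REDIS_PORT/REDIS_PASSWORD env vars (or a reachable Redis) are assumed:

    import litellm
    from litellm.caching import Cache

    # Cache-level namespace (illustrative value)
    litellm.cache = Cache(type="redis", namespace="global_ns")

    # Per-request namespace via metadata, as passed by batch_redis_get's async_pre_call_hook
    key = litellm.cache.get_cache_key(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        metadata={"redis_namespace": "team_a"},
    )
    # key is now "team_a:<sha256 hex>"; before this fix it was "global_ns:<sha256 hex>"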
Krrish Dholakia 2024-09-28 13:00:45 -07:00
parent e9e086a0b6
commit efc06d4a03
5 changed files with 105 additions and 37 deletions


@@ -2401,13 +2401,13 @@ class Cache:
         # Hexadecimal representation of the hash
         hash_hex = hash_object.hexdigest()
         print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
-        if self.namespace is not None:
-            hash_hex = f"{self.namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
-        elif kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
+        if kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
             _namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
             hash_hex = f"{_namespace}:{hash_hex}"
             print_verbose(f"Hashed Key with Namespace: {hash_hex}")
+        elif self.namespace is not None:
+            hash_hex = f"{self.namespace}:{hash_hex}"
+            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
         return hash_hex
 
     def generate_streaming_content(self, content):


@@ -2141,7 +2141,7 @@ def _init_custom_logger_compatible_class(
     llm_router: Optional[
         Any
     ],  # expect litellm.Router, but typing errors due to circular import
-    premium_user: bool = False,
+    premium_user: Optional[bool] = None,
 ) -> Optional[CustomLogger]:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:
@@ -2184,7 +2184,7 @@ def _init_custom_logger_compatible_class(
             _prometheus_logger = PrometheusLogger()
             _in_memory_loggers.append(_prometheus_logger)
             return _prometheus_logger  # type: ignore
-        else:
+        elif premium_user is False:
             verbose_logger.warning(
                 f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
             )


@@ -52,19 +52,10 @@ model_list:
       model: "vertex_ai/gemini-flash-experimental"
 
 litellm_settings:
-  callbacks: ["prometheus"]
-  redact_user_api_key_info: true
-  default_team_settings:
-    - team_id: "09ae376d-f6c8-42cd-88be-59717135684d" # team 1
-      success_callbacks: ["langfuse"]
-      langfuse_public_key: "pk-lf-1"
-      langfuse_secret: "sk-lf-1"
-      langfuse_host: ""
-    - team_id: "e5db79db-d623-4a5b-afd5-162be56074df" # team2
-      success_callback: ["langfuse"]
-      langfuse_public_key: "pk-lf-2"
-      langfuse_secret: "sk-lf-2"
-      langfuse_host: ""
-  json_logs: true
+  cache: true
+  cache_params:
+    type: "redis"
+    # namespace: "litellm_caching"
+    ttl: 900
+  callbacks: ["batch_redis_requests"]


@@ -3,14 +3,17 @@
 ## This reduces the number of REDIS GET requests made during high-traffic by the proxy.
 ### [BETA] this is in Beta. And might change.
-from typing import Optional, Literal
-import litellm
-from litellm.caching import DualCache, RedisCache, InMemoryCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from litellm._logging import verbose_proxy_logger
+import json
+import traceback
+from typing import Literal, Optional
+
 from fastapi import HTTPException
-import json, traceback
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache, InMemoryCache, RedisCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
 
 
 class _PROXY_BatchRedisRequests(CustomLogger):
@@ -18,6 +21,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
     in_memory_cache: Optional[InMemoryCache] = None
 
     def __init__(self):
+        if litellm.cache is not None:
             litellm.cache.async_get_cache = (
                 self.async_get_cache
             )  # map the litellm 'get_cache' function to our custom function
@@ -58,7 +62,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
                 if isinstance(key, str) and key.startswith(cache_key_name):
                     in_memory_cache_exists = True
 
-            if in_memory_cache_exists == False and litellm.cache is not None:
+            if in_memory_cache_exists is False and litellm.cache is not None:
                 """
                 - Check if `litellm.Cache` is redis
                 - Get the relevant values
@@ -105,16 +109,25 @@ class _PROXY_BatchRedisRequests(CustomLogger):
         """
         - Check if the cache key is in-memory
-        - Else return None
+        - Else:
+            - add missing cache key from REDIS
+            - update in-memory cache
+            - return redis cache request
         """
         try:  # never block execution
+            cache_key: Optional[str] = None
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
-            else:
+            elif litellm.cache is not None:
                 cache_key = litellm.cache.get_cache_key(
                     *args, **kwargs
                 )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
-            if cache_key is not None and self.in_memory_cache is not None:
+            if (
+                cache_key is not None
+                and self.in_memory_cache is not None
+                and litellm.cache is not None
+            ):
                 cache_control_args = kwargs.get("cache", {})
                 max_age = cache_control_args.get(
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
@@ -122,8 +135,16 @@ class _PROXY_BatchRedisRequests(CustomLogger):
                 cached_result = self.in_memory_cache.get_cache(
                     cache_key, *args, **kwargs
                 )
+                if cached_result is None:
+                    cached_result = await litellm.cache.cache.async_get_cache(
+                        cache_key, *args, **kwargs
+                    )
+                    if cached_result is not None:
+                        await self.in_memory_cache.async_set_cache(
+                            cache_key, cached_result, ttl=60
+                        )
                 return litellm.cache._get_cache_logic(
                     cached_result=cached_result, max_age=max_age
                 )
-        except Exception as e:
+        except Exception:
             return None


@@ -2100,3 +2100,59 @@ async def test_redis_sentinel_caching():
 
     print(f"stored_val: {stored_val}")
     assert stored_val_2["id"] == response1.id
+
+
+@pytest.mark.asyncio
+async def test_redis_proxy_batch_redis_get_cache():
+    """
+    Tests batch_redis_get.py
+
+    - make 1st call -> expect miss
+    - make 2nd call -> expect hit
+    """
+    from litellm.caching import Cache, DualCache
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests
+
+    litellm.cache = Cache(
+        type="redis",
+        host=os.getenv("REDIS_HOST"),
+        port=os.getenv("REDIS_PORT"),
+        password=os.getenv("REDIS_PASSWORD"),
+        namespace="test_namespace",
+    )
+
+    batch_redis_get_obj = (
+        _PROXY_BatchRedisRequests()
+    )  # overrides the .async_get_cache method
+
+    user_api_key_cache = DualCache()
+
+    import uuid
+
+    batch_redis_get_obj.in_memory_cache = user_api_key_cache.in_memory_cache
+
+    messages = [{"role": "user", "content": "hi {}".format(uuid.uuid4())}]
+    # 1st call -> expect miss
+    response = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        mock_response="hello",
+    )
+    assert response is not None
+    assert "cache_key" not in response._hidden_params
+    print(response._hidden_params)
+
+    await asyncio.sleep(1)
+
+    # 2nd call -> expect hit
+    response = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        mock_response="hello",
+    )
+    print(response._hidden_params)
+    assert "cache_key" in response._hidden_params