mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

fix(batch_redis_get.py): handle custom namespace

Fix https://github.com/BerriAI/litellm/issues/5917

This commit is contained in:
  parent e9e086a0b6
  commit efc06d4a03

5 changed files with 105 additions and 37 deletions
Cache.get_cache_key (litellm/caching.py): a namespace passed per request via metadata["redis_namespace"] now takes precedence over the Cache instance's own namespace when prefixing the hashed key.

@@ -2401,13 +2401,13 @@ class Cache:
         # Hexadecimal representation of the hash
         hash_hex = hash_object.hexdigest()
         print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
-        if self.namespace is not None:
-            hash_hex = f"{self.namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
-        elif kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
+        if kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
             _namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
             hash_hex = f"{_namespace}:{hash_hex}"
             print_verbose(f"Hashed Key with Namespace: {hash_hex}")
+        elif self.namespace is not None:
+            hash_hex = f"{self.namespace}:{hash_hex}"
+            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
         return hash_hex

     def generate_streaming_content(self, content):
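To illustrate the precedence change, here is a minimal standalone sketch of the key-building logic after this commit. The function name, sample namespaces, and the plain SHA-256 hashing are illustrative only, not litellm's exact implementation.

```python
import hashlib
from typing import Optional


def build_cache_key(prompt: str, instance_namespace: Optional[str] = None, **kwargs) -> str:
    """Simplified sketch of the namespace handling after this fix (not litellm's code)."""
    hash_hex = hashlib.sha256(prompt.encode()).hexdigest()
    # Per-request namespace (metadata["redis_namespace"]) now wins ...
    request_namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
    if request_namespace is not None:
        return f"{request_namespace}:{hash_hex}"
    # ... and the Cache-level namespace is only the fallback.
    if instance_namespace is not None:
        return f"{instance_namespace}:{hash_hex}"
    return hash_hex


# Example: the per-request namespace overrides the instance-level one.
print(build_cache_key("hi", instance_namespace="litellm_caching",
                      metadata={"redis_namespace": "team-a"}))       # -> "team-a:<sha256>"
print(build_cache_key("hi", instance_namespace="litellm_caching"))   # -> "litellm_caching:<sha256>"
```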
_init_custom_logger_compatible_class (logging utilities): premium_user becomes a three-state Optional[bool], and the enterprise warning only fires when it is explicitly False.

@@ -2141,7 +2141,7 @@ def _init_custom_logger_compatible_class(
     llm_router: Optional[
         Any
     ],  # expect litellm.Router, but typing errors due to circular import
-    premium_user: bool = False,
+    premium_user: Optional[bool] = None,
 ) -> Optional[CustomLogger]:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:

@@ -2184,7 +2184,7 @@ def _init_custom_logger_compatible_class(
             _prometheus_logger = PrometheusLogger()
             _in_memory_loggers.append(_prometheus_logger)
             return _prometheus_logger  # type: ignore
-        else:
+        elif premium_user is False:
             verbose_logger.warning(
                 f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
             )
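The premium_user change is behavioral as well as cosmetic: with Optional[bool] = None, a caller that never passes the flag no longer triggers the enterprise warning; only an explicit False does. A rough sketch of the resulting three states (names here are illustrative, not litellm's code):

```python
from typing import Optional


def prometheus_gate(premium_user: Optional[bool]) -> str:
    """Illustrative only: mirrors the if/elif shape introduced by this commit."""
    if premium_user:
        return "init PrometheusLogger"          # premium -> logger is created
    elif premium_user is False:
        return "warn: enterprise-only feature"  # explicitly non-premium -> warn
    return "no-op"                              # None (flag not supplied) -> stay quiet


assert prometheus_gate(True) == "init PrometheusLogger"
assert prometheus_gate(False) == "warn: enterprise-only feature"
assert prometheus_gate(None) == "no-op"
```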
Proxy example config (YAML): the langfuse team settings are replaced with Redis caching plus the batch Redis request hook.

@@ -52,19 +52,10 @@ model_list:
       model: "vertex_ai/gemini-flash-experimental"

 litellm_settings:
-  callbacks: ["prometheus"]
-  redact_user_api_key_info: true
-  default_team_settings:
-    - team_id: "09ae376d-f6c8-42cd-88be-59717135684d" # team 1
-      success_callbacks: ["langfuse"]
-      langfuse_public_key: "pk-lf-1"
-      langfuse_secret: "sk-lf-1"
-      langfuse_host: ""
-
-    - team_id: "e5db79db-d623-4a5b-afd5-162be56074df" # team2
-      success_callback: ["langfuse"]
-      langfuse_public_key: "pk-lf-2"
-      langfuse_secret: "sk-lf-2"
-      langfuse_host: ""
+  json_logs: true
+  cache: true
+  cache_params:
+    type: "redis"
+    # namespace: "litellm_caching"
+    ttl: 900
+  callbacks: ["batch_redis_requests"]
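In Python terms, the new cache / cache_params block corresponds roughly to constructing a Redis-backed litellm.Cache and registering the batch hook, much like the test added in this commit does. The snippet below is a sketch under that assumption (connection details from environment variables; the 1:1 mapping of cache_params.ttl to the ttl kwarg is assumed, not confirmed by this diff):

```python
import os

import litellm
from litellm.caching import Cache
from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests

# Roughly what the proxy wires up from `cache: true` + `cache_params` +
# `callbacks: ["batch_redis_requests"]` in the YAML above (sketch, not the proxy's code).
litellm.cache = Cache(
    type="redis",
    host=os.getenv("REDIS_HOST"),
    port=os.getenv("REDIS_PORT"),
    password=os.getenv("REDIS_PASSWORD"),
    # namespace="litellm_caching",  # optional, mirrors the commented-out YAML key
    ttl=900,  # assumed to map directly from cache_params.ttl
)

# Instantiating the hook patches litellm.cache.async_get_cache (see __init__ below).
batch_redis_hook = _PROXY_BatchRedisRequests()
```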
litellm/proxy/hooks/batch_redis_get.py (imports reordered):

@@ -3,14 +3,17 @@
 ## This reduces the number of REDIS GET requests made during high-traffic by the proxy.
 ### [BETA] this is in Beta. And might change.

-from typing import Optional, Literal
-import litellm
-from litellm.caching import DualCache, RedisCache, InMemoryCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
-from litellm._logging import verbose_proxy_logger
+import json
+import traceback
+from typing import Literal, Optional
+
 from fastapi import HTTPException
-import json, traceback
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache, InMemoryCache, RedisCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth


 class _PROXY_BatchRedisRequests(CustomLogger):
litellm/proxy/hooks/batch_redis_get.py (__init__ now guards against litellm.cache being None before patching it):

@@ -18,9 +21,10 @@ class _PROXY_BatchRedisRequests(CustomLogger):
     in_memory_cache: Optional[InMemoryCache] = None

     def __init__(self):
-        litellm.cache.async_get_cache = (
-            self.async_get_cache
-        )  # map the litellm 'get_cache' function to our custom function
+        if litellm.cache is not None:
+            litellm.cache.async_get_cache = (
+                self.async_get_cache
+            )  # map the litellm 'get_cache' function to our custom function

     def print_verbose(
         self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
litellm/proxy/hooks/batch_redis_get.py (identity comparison instead of == for the boolean check):

@@ -58,7 +62,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
                 if isinstance(key, str) and key.startswith(cache_key_name):
                     in_memory_cache_exists = True

-            if in_memory_cache_exists == False and litellm.cache is not None:
+            if in_memory_cache_exists is False and litellm.cache is not None:
                 """
                 - Check if `litellm.Cache` is redis
                 - Get the relevant values
litellm/proxy/hooks/batch_redis_get.py (async_get_cache: guard the optional values and document the Redis fallback):

@@ -105,16 +109,25 @@ class _PROXY_BatchRedisRequests(CustomLogger):
         """
         - Check if the cache key is in-memory

-        - Else return None
+        - Else:
+            - add missing cache key from REDIS
+            - update in-memory cache
+            - return redis cache request
         """
         try:  # never block execution
+            cache_key: Optional[str] = None
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
-            else:
+            elif litellm.cache is not None:
                 cache_key = litellm.cache.get_cache_key(
                     *args, **kwargs
                 )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
-            if cache_key is not None and self.in_memory_cache is not None:
+
+            if (
+                cache_key is not None
+                and self.in_memory_cache is not None
+                and litellm.cache is not None
+            ):
                 cache_control_args = kwargs.get("cache", {})
                 max_age = cache_control_args.get(
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
@@ -122,8 +135,16 @@ class _PROXY_BatchRedisRequests(CustomLogger):
                 cached_result = self.in_memory_cache.get_cache(
                     cache_key, *args, **kwargs
                 )
+                if cached_result is None:
+                    cached_result = await litellm.cache.cache.async_get_cache(
+                        cache_key, *args, **kwargs
+                    )
+                    if cached_result is not None:
+                        await self.in_memory_cache.async_set_cache(
+                            cache_key, cached_result, ttl=60
+                        )
                 return litellm.cache._get_cache_logic(
                     cached_result=cached_result, max_age=max_age
                 )
-        except Exception as e:
+        except Exception:
             return None
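Taken together, these two hunks turn the hook into a small read-through cache: check the in-memory cache first, fall back to Redis on a miss, and warm the in-memory copy with a short TTL so repeated hits stay local. A generic, self-contained sketch of that pattern, using toy stand-in classes rather than litellm's cache objects:

```python
import asyncio
from typing import Any, Dict, Optional


class TinyStore:
    """Stand-in for both the in-memory cache and the Redis client in this sketch."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    def get(self, key: str) -> Optional[Any]:
        return self._data.get(key)

    async def aget(self, key: str) -> Optional[Any]:
        return self._data.get(key)

    async def aset(self, key: str, value: Any, ttl: int = 60) -> None:
        self._data[key] = value  # TTL ignored in this toy version


async def read_through(key: str, local: TinyStore, redis: TinyStore) -> Optional[Any]:
    value = local.get(key)              # 1) in-memory lookup
    if value is None:
        value = await redis.aget(key)   # 2) fall back to Redis on a miss
        if value is not None:
            await local.aset(key, value, ttl=60)  # 3) briefly warm the local cache
    return value


async def main() -> None:
    local, redis = TinyStore(), TinyStore()
    await redis.aset("ns:abc123", {"id": "resp-1"})
    print(await read_through("ns:abc123", local, redis))  # served from Redis, then cached locally
    print(await read_through("ns:abc123", local, redis))  # now served from the in-memory copy


asyncio.run(main())
```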
Caching tests: a new test exercising the batch Redis get hook with a namespaced cache (miss on the first call, hit on the second).

@@ -2100,3 +2100,59 @@ async def test_redis_sentinel_caching():

     print(f"stored_val: {stored_val}")
     assert stored_val_2["id"] == response1.id
+
+
+@pytest.mark.asyncio
+async def test_redis_proxy_batch_redis_get_cache():
+    """
+    Tests batch_redis_get.py
+
+    - make 1st call -> expect miss
+    - make 2nd call -> expect hit
+    """
+
+    from litellm.caching import Cache, DualCache
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests
+
+    litellm.cache = Cache(
+        type="redis",
+        host=os.getenv("REDIS_HOST"),
+        port=os.getenv("REDIS_PORT"),
+        password=os.getenv("REDIS_PASSWORD"),
+        namespace="test_namespace",
+    )
+
+    batch_redis_get_obj = (
+        _PROXY_BatchRedisRequests()
+    )  # overrides the .async_get_cache method
+
+    user_api_key_cache = DualCache()
+
+    import uuid
+
+    batch_redis_get_obj.in_memory_cache = user_api_key_cache.in_memory_cache
+
+    messages = [{"role": "user", "content": "hi {}".format(uuid.uuid4())}]
+    # 1st call -> expect miss
+    response = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        mock_response="hello",
+    )
+
+    assert response is not None
+    assert "cache_key" not in response._hidden_params
+    print(response._hidden_params)
+
+    await asyncio.sleep(1)
+
+    # 2nd call -> expect hit
+    response = await litellm.acompletion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+        mock_response="hello",
+    )
+
+    print(response._hidden_params)
+    assert "cache_key" in response._hidden_params
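As written, this test needs a reachable Redis instance configured via REDIS_HOST, REDIS_PORT, and REDIS_PASSWORD: the first acompletion call should surface no "cache_key" in response._hidden_params (a miss), while the second should include one (a hit) after the hook has warmed the in-memory cache. It can presumably be run on its own with pytest's -k filter, e.g. `pytest -k test_redis_proxy_batch_redis_get_cache`.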