forked from phoenix/litellm-mirror
fix(batch_redis_get.py): handle custom namespace
Fix https://github.com/BerriAI/litellm/issues/5917
This commit is contained in:
parent
e9e086a0b6
commit
efc06d4a03
5 changed files with 105 additions and 37 deletions
|
@@ -2401,13 +2401,13 @@ class Cache:
|
|||
# Hexadecimal representation of the hash
|
||||
hash_hex = hash_object.hexdigest()
|
||||
print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
|
||||
if self.namespace is not None:
|
||||
hash_hex = f"{self.namespace}:{hash_hex}"
|
||||
print_verbose(f"Hashed Key with Namespace: {hash_hex}")
|
||||
elif kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
|
||||
if kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
|
||||
_namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
|
||||
hash_hex = f"{_namespace}:{hash_hex}"
|
||||
print_verbose(f"Hashed Key with Namespace: {hash_hex}")
|
||||
elif self.namespace is not None:
|
||||
hash_hex = f"{self.namespace}:{hash_hex}"
|
||||
print_verbose(f"Hashed Key with Namespace: {hash_hex}")
|
||||
return hash_hex
|
||||
|
||||
def generate_streaming_content(self, content):
|
||||
|
|
|
@@ -2141,7 +2141,7 @@ def _init_custom_logger_compatible_class(
|
|||
llm_router: Optional[
|
||||
Any
|
||||
], # expect litellm.Router, but typing errors due to circular import
|
||||
premium_user: bool = False,
|
||||
premium_user: Optional[bool] = None,
|
||||
) -> Optional[CustomLogger]:
|
||||
if logging_integration == "lago":
|
||||
for callback in _in_memory_loggers:
|
||||
|
@@ -2184,7 +2184,7 @@ def _init_custom_logger_compatible_class(
|
|||
_prometheus_logger = PrometheusLogger()
|
||||
_in_memory_loggers.append(_prometheus_logger)
|
||||
return _prometheus_logger # type: ignore
|
||||
else:
|
||||
elif premium_user is False:
|
||||
verbose_logger.warning(
|
||||
f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
|
|
|
@@ -52,19 +52,10 @@ model_list:
|
|||
model: "vertex_ai/gemini-flash-experimental"
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
||||
redact_user_api_key_info: true
|
||||
|
||||
default_team_settings:
|
||||
- team_id: "09ae376d-f6c8-42cd-88be-59717135684d" # team 1
|
||||
success_callbacks: ["langfuse"]
|
||||
langfuse_public_key: "pk-lf-1"
|
||||
langfuse_secret: "sk-lf-1"
|
||||
langfuse_host: ""
|
||||
|
||||
- team_id: "e5db79db-d623-4a5b-afd5-162be56074df" # team2
|
||||
success_callback: ["langfuse"]
|
||||
langfuse_public_key: "pk-lf-2"
|
||||
langfuse_secret: "sk-lf-2"
|
||||
langfuse_host: ""
|
||||
|
||||
json_logs: true
|
||||
cache: true
|
||||
cache_params:
|
||||
type: "redis"
|
||||
# namespace: "litellm_caching"
|
||||
ttl: 900
|
||||
callbacks: ["batch_redis_requests"]
|
||||
|
|
|
@@ -3,14 +3,17 @@
|
|||
## This reduces the number of REDIS GET requests made during high-traffic by the proxy.
|
||||
### [BETA] this is in Beta. And might change.
|
||||
|
||||
from typing import Optional, Literal
|
||||
import litellm
|
||||
from litellm.caching import DualCache, RedisCache, InMemoryCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
import json
|
||||
import traceback
|
||||
from typing import Literal, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
import json, traceback
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache, InMemoryCache, RedisCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class _PROXY_BatchRedisRequests(CustomLogger):
|
||||
|
@@ -18,9 +21,10 @@ class _PROXY_BatchRedisRequests(CustomLogger):
|
|||
in_memory_cache: Optional[InMemoryCache] = None
|
||||
|
||||
def __init__(self):
|
||||
litellm.cache.async_get_cache = (
|
||||
self.async_get_cache
|
||||
) # map the litellm 'get_cache' function to our custom function
|
||||
if litellm.cache is not None:
|
||||
litellm.cache.async_get_cache = (
|
||||
self.async_get_cache
|
||||
) # map the litellm 'get_cache' function to our custom function
|
||||
|
||||
def print_verbose(
|
||||
self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
|
||||
|
@@ -58,7 +62,7 @@ class _PROXY_BatchRedisRequests(CustomLogger):
|
|||
if isinstance(key, str) and key.startswith(cache_key_name):
|
||||
in_memory_cache_exists = True
|
||||
|
||||
if in_memory_cache_exists == False and litellm.cache is not None:
|
||||
if in_memory_cache_exists is False and litellm.cache is not None:
|
||||
"""
|
||||
- Check if `litellm.Cache` is redis
|
||||
- Get the relevant values
|
||||
|
@@ -105,16 +109,25 @@ class _PROXY_BatchRedisRequests(CustomLogger):
|
|||
"""
|
||||
- Check if the cache key is in-memory
|
||||
|
||||
- Else return None
|
||||
- Else:
|
||||
- add missing cache key from REDIS
|
||||
- update in-memory cache
|
||||
- return redis cache request
|
||||
"""
|
||||
try: # never block execution
|
||||
cache_key: Optional[str] = None
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
elif litellm.cache is not None:
|
||||
cache_key = litellm.cache.get_cache_key(
|
||||
*args, **kwargs
|
||||
) # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
|
||||
if cache_key is not None and self.in_memory_cache is not None:
|
||||
|
||||
if (
|
||||
cache_key is not None
|
||||
and self.in_memory_cache is not None
|
||||
and litellm.cache is not None
|
||||
):
|
||||
cache_control_args = kwargs.get("cache", {})
|
||||
max_age = cache_control_args.get(
|
||||
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
|
||||
|
@@ -122,8 +135,16 @@ class _PROXY_BatchRedisRequests(CustomLogger):
|
|||
cached_result = self.in_memory_cache.get_cache(
|
||||
cache_key, *args, **kwargs
|
||||
)
|
||||
if cached_result is None:
|
||||
cached_result = await litellm.cache.cache.async_get_cache(
|
||||
cache_key, *args, **kwargs
|
||||
)
|
||||
if cached_result is not None:
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
cache_key, cached_result, ttl=60
|
||||
)
|
||||
return litellm.cache._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
|
|
@@ -2100,3 +2100,59 @@ async def test_redis_sentinel_caching():
|
|||
|
||||
print(f"stored_val: {stored_val}")
|
||||
assert stored_val_2["id"] == response1.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_proxy_batch_redis_get_cache():
|
||||
"""
|
||||
Tests batch_redis_get.py
|
||||
|
||||
- make 1st call -> expect miss
|
||||
- make 2nd call -> expect hit
|
||||
"""
|
||||
|
||||
from litellm.caching import Cache, DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.hooks.batch_redis_get import _PROXY_BatchRedisRequests
|
||||
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.getenv("REDIS_HOST"),
|
||||
port=os.getenv("REDIS_PORT"),
|
||||
password=os.getenv("REDIS_PASSWORD"),
|
||||
namespace="test_namespace",
|
||||
)
|
||||
|
||||
batch_redis_get_obj = (
|
||||
_PROXY_BatchRedisRequests()
|
||||
) # overrides the .async_get_cache method
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
|
||||
import uuid
|
||||
|
||||
batch_redis_get_obj.in_memory_cache = user_api_key_cache.in_memory_cache
|
||||
|
||||
messages = [{"role": "user", "content": "hi {}".format(uuid.uuid4())}]
|
||||
# 1st call -> expect miss
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
mock_response="hello",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert "cache_key" not in response._hidden_params
|
||||
print(response._hidden_params)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 2nd call -> expect hit
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
mock_response="hello",
|
||||
)
|
||||
|
||||
print(response._hidden_params)
|
||||
assert "cache_key" in response._hidden_params
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue