forked from phoenix/litellm-mirror
(bug fix) TTL not being set for embedding caching requests (#6095)
* fix ttl for cache pipeline settings * add test for caching * add test for setting ttls on redis caching
This commit is contained in:
parent
285b589095
commit
2c8bba293f
2 changed files with 99 additions and 14 deletions
|
@ -18,7 +18,7 @@ import time
|
|||
import traceback
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from openai._models import BaseModel as OpenAIObject
|
||||
|
||||
|
@ -455,7 +455,9 @@ class RedisCache(BaseCache):
|
|||
value,
|
||||
)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
|
||||
async def async_set_cache_pipeline(
|
||||
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Use Redis Pipelines for bulk write operations
|
||||
"""
|
||||
|
@ -467,6 +469,8 @@ class RedisCache(BaseCache):
|
|||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
start_time = time.time()
|
||||
|
||||
ttl = ttl or kwargs.get("ttl", None)
|
||||
|
||||
print_verbose(
|
||||
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
||||
)
|
||||
|
@ -482,11 +486,10 @@ class RedisCache(BaseCache):
|
|||
)
|
||||
json_cache_value = json.dumps(cache_value)
|
||||
# Set the value with a TTL if it's provided.
|
||||
|
||||
_td: Optional[timedelta] = None
|
||||
if ttl is not None:
|
||||
pipe.setex(cache_key, ttl, json_cache_value)
|
||||
else:
|
||||
pipe.set(cache_key, json_cache_value)
|
||||
_td = timedelta(seconds=ttl)
|
||||
pipe.set(cache_key, json_cache_value, ex=_td)
|
||||
# Execute the pipeline and return the results.
|
||||
results = await pipe.execute()
|
||||
|
||||
|
@ -2188,14 +2191,38 @@ class Cache:
|
|||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||
|
||||
# Redis Cache Args
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
|
||||
ttl (float, optional): The ttl for the Redis cache
|
||||
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
|
||||
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
|
||||
|
||||
# Qdrant Cache Args
|
||||
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||
|
||||
# Disk Cache Args
|
||||
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
|
||||
|
||||
# S3 Cache Args
|
||||
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
|
||||
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
|
||||
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
|
||||
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
|
||||
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
|
||||
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
|
||||
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
|
||||
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
|
||||
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
|
||||
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
|
||||
|
||||
# Common Cache Args
|
||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
||||
|
@ -2207,18 +2234,18 @@ class Cache:
|
|||
"""
|
||||
if type == "redis":
|
||||
self.cache: BaseCache = RedisCache(
|
||||
host,
|
||||
port,
|
||||
password,
|
||||
redis_flush_size,
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
startup_nodes=redis_startup_nodes,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == "redis-semantic":
|
||||
self.cache = RedisSemanticCache(
|
||||
host,
|
||||
port,
|
||||
password,
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
similarity_threshold=similarity_threshold,
|
||||
use_async=redis_semantic_cache_use_async,
|
||||
embedding_model=redis_semantic_cache_embedding_model,
|
||||
|
@ -2598,6 +2625,11 @@ class Cache:
|
|||
try:
|
||||
if self.should_use_cache(*args, **kwargs) is not True:
|
||||
return
|
||||
|
||||
# set default ttl if not set
|
||||
if self.ttl is not None:
|
||||
kwargs["ttl"] = self.ttl
|
||||
|
||||
cache_list = []
|
||||
for idx, i in enumerate(kwargs["input"]):
|
||||
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
|
||||
|
@ -2613,7 +2645,7 @@ class Cache:
|
|||
self.cache, "async_set_cache_pipeline", None
|
||||
)
|
||||
if async_set_cache_pipeline:
|
||||
await async_set_cache_pipeline(cache_list=cache_list)
|
||||
await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
|
||||
else:
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
|
|
|
@ -22,6 +22,9 @@ import litellm
|
|||
from litellm import aembedding, completion, embedding
|
||||
from litellm.caching import Cache
|
||||
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import datetime
|
||||
|
||||
# litellm.set_verbose=True
|
||||
|
||||
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
||||
|
@ -579,6 +582,56 @@ async def test_embedding_caching_base_64():
|
|||
assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_embedding_caching_redis_ttl():
    """
    Test default_in_redis_ttl is used for embedding caching

    The Redis pipeline is mocked, so no key is ever written or expired; the
    test only inspects the ``ex`` keyword passed to ``pipe.set``.

    issue: https://github.com/BerriAI/litellm/issues/6010
    """
    # Remember the module-level settings so this test does not leak its
    # cache / verbosity configuration into tests that run after it.
    previous_cache = litellm.cache
    previous_verbose = litellm.set_verbose
    litellm.set_verbose = True

    # Create a mock for the pipeline.
    # `pipe.set(...)` is called WITHOUT `await` by the cache code, so it has to
    # be a plain MagicMock; an AsyncMock would hand back a coroutine that is
    # never awaited (RuntimeWarning) on every call.
    mock_pipeline = AsyncMock()
    mock_set = MagicMock()
    mock_pipeline.__aenter__.return_value.set = mock_set

    try:
        # Patch the Redis class to return our mock
        with patch("redis.asyncio.Redis.pipeline", return_value=mock_pipeline):
            # Simulate the context manager behavior for the pipeline
            litellm.cache = Cache(
                type="redis",
                host="dummy_host",
                password="dummy_password",
                default_in_redis_ttl=2.5,
            )

            inputs = [
                f"{uuid.uuid4()} hello this is ishaan",
                f"{uuid.uuid4()} hello this is ishaan again",
            ]

            # Call the embedding method (the response itself is not inspected)
            await litellm.aembedding(
                model="azure/azure-embedding-model",
                input=inputs,
                encoding_format="base64",
            )

            # NOTE(review): Redis is mocked, so nothing actually expires here.
            # The pause presumably gives the background cache write time to
            # run while the patch is still active - TODO confirm and shorten.
            await asyncio.sleep(3)

            # Check if set was called on the pipeline
            mock_set.assert_called()

            # Check if the TTL was set correctly
            for call in mock_set.call_args_list:
                args, kwargs = call
                print(f"redis pipeline set args: {args}")
                print(f"redis pipeline set kwargs: {kwargs}")
                # TTL must be forwarded as a 2.5 second timedelta
                assert kwargs.get("ex") == datetime.timedelta(seconds=2.5)
    finally:
        litellm.cache = previous_cache
        litellm.set_verbose = previous_verbose
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_cache_basic():
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue