(bug fix) TTL not being set for embedding caching requests (#6095)

* fix ttl for cache pipeline settings

* add test for caching

* add test for setting ttls on redis caching
This commit is contained in:
Ishaan Jaff 2024-10-07 15:53:18 +05:30 committed by GitHub
parent 285b589095
commit 2c8bba293f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 99 additions and 14 deletions

View file

@@ -18,7 +18,7 @@ import time
import traceback import traceback
from datetime import timedelta from datetime import timedelta
from enum import Enum from enum import Enum
from typing import Any, List, Literal, Optional, Union from typing import Any, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
@@ -455,7 +455,9 @@ class RedisCache(BaseCache):
value, value,
) )
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs): async def async_set_cache_pipeline(
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
):
""" """
Use Redis Pipelines for bulk write operations Use Redis Pipelines for bulk write operations
""" """
@@ -467,6 +469,8 @@ class RedisCache(BaseCache):
_redis_client: Redis = self.init_async_client() # type: ignore _redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time() start_time = time.time()
ttl = ttl or kwargs.get("ttl", None)
print_verbose( print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}" f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
) )
@@ -482,11 +486,10 @@
) )
json_cache_value = json.dumps(cache_value) json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided. # Set the value with a TTL if it's provided.
_td: Optional[timedelta] = None
if ttl is not None: if ttl is not None:
pipe.setex(cache_key, ttl, json_cache_value) _td = timedelta(seconds=ttl)
else: pipe.set(cache_key, json_cache_value, ex=_td)
pipe.set(cache_key, json_cache_value)
# Execute the pipeline and return the results. # Execute the pipeline and return the results.
results = await pipe.execute() results = await pipe.execute()
@@ -2188,14 +2191,38 @@ class Cache:
Args: Args:
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
# Redis Cache Args
host (str, optional): The host address for the Redis cache. Required if type is "redis". host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis".
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
ttl (float, optional): The ttl for the Redis cache
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
# Qdrant Cache Args
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic". qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic". qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic". similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
# Disk Cache Args
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
# S3 Cache Args
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
# Common Cache Args
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types. supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache **kwargs: Additional keyword arguments for redis.Redis() cache
@@ -2207,18 +2234,18 @@ class Cache:
""" """
if type == "redis": if type == "redis":
self.cache: BaseCache = RedisCache( self.cache: BaseCache = RedisCache(
host, host=host,
port, port=port,
password, password=password,
redis_flush_size, redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes, startup_nodes=redis_startup_nodes,
**kwargs, **kwargs,
) )
elif type == "redis-semantic": elif type == "redis-semantic":
self.cache = RedisSemanticCache( self.cache = RedisSemanticCache(
host, host=host,
port, port=port,
password, password=password,
similarity_threshold=similarity_threshold, similarity_threshold=similarity_threshold,
use_async=redis_semantic_cache_use_async, use_async=redis_semantic_cache_use_async,
embedding_model=redis_semantic_cache_embedding_model, embedding_model=redis_semantic_cache_embedding_model,
@@ -2598,6 +2625,11 @@ class Cache:
try: try:
if self.should_use_cache(*args, **kwargs) is not True: if self.should_use_cache(*args, **kwargs) is not True:
return return
# set default ttl if not set
if self.ttl is not None:
kwargs["ttl"] = self.ttl
cache_list = [] cache_list = []
for idx, i in enumerate(kwargs["input"]): for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i}) preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
@@ -2613,7 +2645,7 @@
self.cache, "async_set_cache_pipeline", None self.cache, "async_set_cache_pipeline", None
) )
if async_set_cache_pipeline: if async_set_cache_pipeline:
await async_set_cache_pipeline(cache_list=cache_list) await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
else: else:
tasks = [] tasks = []
for val in cache_list: for val in cache_list:

View file

@@ -22,6 +22,9 @@ import litellm
from litellm import aembedding, completion, embedding from litellm import aembedding, completion, embedding
from litellm.caching import Cache from litellm.caching import Cache
from unittest.mock import AsyncMock, patch, MagicMock
import datetime
# litellm.set_verbose=True # litellm.set_verbose=True
messages = [{"role": "user", "content": "who is ishaan Github? "}] messages = [{"role": "user", "content": "who is ishaan Github? "}]
@@ -579,6 +582,56 @@ async def test_embedding_caching_base_64():
assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"] assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"]
@pytest.mark.asyncio
async def test_embedding_caching_redis_ttl():
"""
Test default_in_redis_ttl is used for embedding caching
issue: https://github.com/BerriAI/litellm/issues/6010
"""
litellm.set_verbose = True
# Create a mock for the pipeline
mock_pipeline = AsyncMock()
mock_set = AsyncMock()
mock_pipeline.__aenter__.return_value.set = mock_set
# Patch the Redis class to return our mock
with patch("redis.asyncio.Redis.pipeline", return_value=mock_pipeline):
# Simulate the context manager behavior for the pipeline
litellm.cache = Cache(
type="redis",
host="dummy_host",
password="dummy_password",
default_in_redis_ttl=2.5,
)
inputs = [
f"{uuid.uuid4()} hello this is ishaan",
f"{uuid.uuid4()} hello this is ishaan again",
]
# Call the embedding method
embedding_val_1 = await litellm.aembedding(
model="azure/azure-embedding-model",
input=inputs,
encoding_format="base64",
)
await asyncio.sleep(3) # Wait for TTL to expire
# Check if set was called on the pipeline
mock_set.assert_called()
# Check if the TTL was set correctly
for call in mock_set.call_args_list:
args, kwargs = call
print(f"redis pipeline set args: {args}")
print(f"redis pipeline set kwargs: {kwargs}")
assert kwargs.get("ex") == datetime.timedelta(
seconds=2.5
) # Check if TTL is set to 2.5 seconds
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_redis_cache_basic(): async def test_redis_cache_basic():
""" """