From 2c8bba293f2042ff9ccab7d1457a2476a1aef501 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 7 Oct 2024 15:53:18 +0530 Subject: [PATCH] (bug fix) TTL not being set for embedding caching requests (#6095) * fix ttl for cache pipeline settings * add test for caching * add test for setting ttls on redis caching --- litellm/caching.py | 60 ++++++++++++++++++++++------- tests/local_testing/test_caching.py | 53 +++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 14 deletions(-) diff --git a/litellm/caching.py b/litellm/caching.py index 68e978e4e..91d9e6996 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -18,7 +18,7 @@ import time import traceback from datetime import timedelta from enum import Enum -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Tuple, Union from openai._models import BaseModel as OpenAIObject @@ -455,7 +455,9 @@ class RedisCache(BaseCache): value, ) - async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs): + async def async_set_cache_pipeline( + self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs + ): """ Use Redis Pipelines for bulk write operations """ @@ -467,6 +469,8 @@ class RedisCache(BaseCache): _redis_client: Redis = self.init_async_client() # type: ignore start_time = time.time() + ttl = ttl or kwargs.get("ttl", None) + print_verbose( f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}" ) @@ -482,11 +486,10 @@ class RedisCache(BaseCache): ) json_cache_value = json.dumps(cache_value) # Set the value with a TTL if it's provided. - + _td: Optional[timedelta] = None if ttl is not None: - pipe.setex(cache_key, ttl, json_cache_value) - else: - pipe.set(cache_key, json_cache_value) + _td = timedelta(seconds=ttl) + pipe.set(cache_key, json_cache_value, ex=_td) # Execute the pipeline and return the results. results = await pipe.execute() @@ -2188,14 +2191,38 @@ class Cache: Args: type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". + + # Redis Cache Args host (str, optional): The host address for the Redis cache. Required if type is "redis". port (int, optional): The port number for the Redis cache. Required if type is "redis". password (str, optional): The password for the Redis cache. Required if type is "redis". + namespace (str, optional): The namespace for the Redis cache. Required if type is "redis". + ttl (float, optional): The ttl for the Redis cache + redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used. + redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None. + + # Qdrant Cache Args qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic". qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic". similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic". + # Disk Cache Args + disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None. + + # S3 Cache Args + s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None. + s3_region_name (str, optional): The region name for the s3 cache. Defaults to None. + s3_api_version (str, optional): The api version for the s3 cache. Defaults to None. + s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True. + s3_verify (bool, optional): The verify for the s3 cache. Defaults to None. + s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None. + s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None. + s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None. + s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None. + s3_config (dict, optional): The config for the s3 cache. Defaults to None. + + # Common Cache Args supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types. **kwargs: Additional keyword arguments for redis.Redis() cache @@ -2207,18 +2234,18 @@ class Cache: """ if type == "redis": self.cache: BaseCache = RedisCache( - host, - port, - password, - redis_flush_size, + host=host, + port=port, + password=password, + redis_flush_size=redis_flush_size, startup_nodes=redis_startup_nodes, **kwargs, ) elif type == "redis-semantic": self.cache = RedisSemanticCache( - host, - port, - password, + host=host, + port=port, + password=password, similarity_threshold=similarity_threshold, use_async=redis_semantic_cache_use_async, embedding_model=redis_semantic_cache_embedding_model, @@ -2598,6 +2625,11 @@ class Cache: try: if self.should_use_cache(*args, **kwargs) is not True: return + + # set default ttl if not set + if self.ttl is not None: + kwargs["ttl"] = self.ttl + cache_list = [] for idx, i in enumerate(kwargs["input"]): preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i}) @@ -2613,7 +2645,7 @@ class Cache: self.cache, "async_set_cache_pipeline", None ) if async_set_cache_pipeline: - await async_set_cache_pipeline(cache_list=cache_list) + await async_set_cache_pipeline(cache_list=cache_list, **kwargs) else: tasks = [] for val in cache_list: diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index 0b9ef30c1..0e9d1f6f2 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -22,6 +22,9 @@ import litellm from litellm import aembedding, completion, embedding from litellm.caching import Cache +from unittest.mock import AsyncMock, patch, MagicMock +import datetime + # litellm.set_verbose=True messages = [{"role": "user", "content": "who is ishaan Github? "}] @@ -579,6 +582,56 @@ async def test_embedding_caching_base_64(): assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"] +@pytest.mark.asyncio +async def test_embedding_caching_redis_ttl(): + """ + Test default_in_redis_ttl is used for embedding caching + + issue: https://github.com/BerriAI/litellm/issues/6010 + """ + litellm.set_verbose = True + + # Create a mock for the pipeline + mock_pipeline = AsyncMock() + mock_set = AsyncMock() + mock_pipeline.__aenter__.return_value.set = mock_set + # Patch the Redis class to return our mock + with patch("redis.asyncio.Redis.pipeline", return_value=mock_pipeline): + # Simulate the context manager behavior for the pipeline + litellm.cache = Cache( + type="redis", + host="dummy_host", + password="dummy_password", + default_in_redis_ttl=2.5, + ) + + inputs = [ + f"{uuid.uuid4()} hello this is ishaan", + f"{uuid.uuid4()} hello this is ishaan again", + ] + + # Call the embedding method + embedding_val_1 = await litellm.aembedding( + model="azure/azure-embedding-model", + input=inputs, + encoding_format="base64", + ) + + await asyncio.sleep(3) # Wait for TTL to expire + + # Check if set was called on the pipeline + mock_set.assert_called() + + # Check if the TTL was set correctly + for call in mock_set.call_args_list: + args, kwargs = call + print(f"redis pipeline set args: {args}") + print(f"redis pipeline set kwargs: {kwargs}") + assert kwargs.get("ex") == datetime.timedelta( + seconds=2.5 + ) # Check if TTL is set to 2.5 seconds + + @pytest.mark.asyncio async def test_redis_cache_basic(): """