forked from phoenix/litellm-mirror
(bug fix) TTL not being set for embedding caching requests (#6095)
* fix ttl for cache pipeline settings
* add test for caching
* add test for setting ttls on redis caching
This commit is contained in:
parent
285b589095
commit
2c8bba293f
2 changed files with 99 additions and 14 deletions
|
@ -18,7 +18,7 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Literal, Optional, Union
|
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from openai._models import BaseModel as OpenAIObject
|
from openai._models import BaseModel as OpenAIObject
|
||||||
|
|
||||||
|
@ -455,7 +455,9 @@ class RedisCache(BaseCache):
|
||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
|
async def async_set_cache_pipeline(
|
||||||
|
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Use Redis Pipelines for bulk write operations
|
Use Redis Pipelines for bulk write operations
|
||||||
"""
|
"""
|
||||||
|
@ -467,6 +469,8 @@ class RedisCache(BaseCache):
|
||||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
ttl = ttl or kwargs.get("ttl", None)
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
||||||
)
|
)
|
||||||
|
@ -482,11 +486,10 @@ class RedisCache(BaseCache):
|
||||||
)
|
)
|
||||||
json_cache_value = json.dumps(cache_value)
|
json_cache_value = json.dumps(cache_value)
|
||||||
# Set the value with a TTL if it's provided.
|
# Set the value with a TTL if it's provided.
|
||||||
|
_td: Optional[timedelta] = None
|
||||||
if ttl is not None:
|
if ttl is not None:
|
||||||
pipe.setex(cache_key, ttl, json_cache_value)
|
_td = timedelta(seconds=ttl)
|
||||||
else:
|
pipe.set(cache_key, json_cache_value, ex=_td)
|
||||||
pipe.set(cache_key, json_cache_value)
|
|
||||||
# Execute the pipeline and return the results.
|
# Execute the pipeline and return the results.
|
||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
@ -2188,14 +2191,38 @@ class Cache:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||||
|
|
||||||
|
# Redis Cache Args
|
||||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||||
|
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
|
||||||
|
ttl (float, optional): The ttl for the Redis cache
|
||||||
|
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
|
||||||
|
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
|
||||||
|
|
||||||
|
# Qdrant Cache Args
|
||||||
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
||||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||||
|
|
||||||
|
# Disk Cache Args
|
||||||
|
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
|
||||||
|
|
||||||
|
# S3 Cache Args
|
||||||
|
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
|
||||||
|
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
|
||||||
|
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
|
||||||
|
s3_use_ssl (bool, optional): Whether to use SSL for the s3 cache. Defaults to True.
|
||||||
|
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
|
||||||
|
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
|
||||||
|
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
|
||||||
|
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
|
||||||
|
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
|
||||||
|
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
|
||||||
|
|
||||||
|
# Common Cache Args
|
||||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||||
|
|
||||||
|
@ -2207,18 +2234,18 @@ class Cache:
|
||||||
"""
|
"""
|
||||||
if type == "redis":
|
if type == "redis":
|
||||||
self.cache: BaseCache = RedisCache(
|
self.cache: BaseCache = RedisCache(
|
||||||
host,
|
host=host,
|
||||||
port,
|
port=port,
|
||||||
password,
|
password=password,
|
||||||
redis_flush_size,
|
redis_flush_size=redis_flush_size,
|
||||||
startup_nodes=redis_startup_nodes,
|
startup_nodes=redis_startup_nodes,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif type == "redis-semantic":
|
elif type == "redis-semantic":
|
||||||
self.cache = RedisSemanticCache(
|
self.cache = RedisSemanticCache(
|
||||||
host,
|
host=host,
|
||||||
port,
|
port=port,
|
||||||
password,
|
password=password,
|
||||||
similarity_threshold=similarity_threshold,
|
similarity_threshold=similarity_threshold,
|
||||||
use_async=redis_semantic_cache_use_async,
|
use_async=redis_semantic_cache_use_async,
|
||||||
embedding_model=redis_semantic_cache_embedding_model,
|
embedding_model=redis_semantic_cache_embedding_model,
|
||||||
|
@ -2598,6 +2625,11 @@ class Cache:
|
||||||
try:
|
try:
|
||||||
if self.should_use_cache(*args, **kwargs) is not True:
|
if self.should_use_cache(*args, **kwargs) is not True:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# set default ttl if not set
|
||||||
|
if self.ttl is not None:
|
||||||
|
kwargs["ttl"] = self.ttl
|
||||||
|
|
||||||
cache_list = []
|
cache_list = []
|
||||||
for idx, i in enumerate(kwargs["input"]):
|
for idx, i in enumerate(kwargs["input"]):
|
||||||
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
|
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
|
||||||
|
@ -2613,7 +2645,7 @@ class Cache:
|
||||||
self.cache, "async_set_cache_pipeline", None
|
self.cache, "async_set_cache_pipeline", None
|
||||||
)
|
)
|
||||||
if async_set_cache_pipeline:
|
if async_set_cache_pipeline:
|
||||||
await async_set_cache_pipeline(cache_list=cache_list)
|
await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
|
||||||
else:
|
else:
|
||||||
tasks = []
|
tasks = []
|
||||||
for val in cache_list:
|
for val in cache_list:
|
||||||
|
|
|
@ -22,6 +22,9 @@ import litellm
|
||||||
from litellm import aembedding, completion, embedding
|
from litellm import aembedding, completion, embedding
|
||||||
from litellm.caching import Cache
|
from litellm.caching import Cache
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
import datetime
|
||||||
|
|
||||||
# litellm.set_verbose=True
|
# litellm.set_verbose=True
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
messages = [{"role": "user", "content": "who is ishaan Github? "}]
|
||||||
|
@ -579,6 +582,56 @@ async def test_embedding_caching_base_64():
|
||||||
assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"]
|
assert embedding_val_2.data[1]["embedding"] == embedding_val_1.data[1]["embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_caching_redis_ttl():
|
||||||
|
"""
|
||||||
|
Test default_in_redis_ttl is used for embedding caching
|
||||||
|
|
||||||
|
issue: https://github.com/BerriAI/litellm/issues/6010
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
# Create a mock for the pipeline
|
||||||
|
mock_pipeline = AsyncMock()
|
||||||
|
mock_set = AsyncMock()
|
||||||
|
mock_pipeline.__aenter__.return_value.set = mock_set
|
||||||
|
# Patch the Redis class to return our mock
|
||||||
|
with patch("redis.asyncio.Redis.pipeline", return_value=mock_pipeline):
|
||||||
|
# Simulate the context manager behavior for the pipeline
|
||||||
|
litellm.cache = Cache(
|
||||||
|
type="redis",
|
||||||
|
host="dummy_host",
|
||||||
|
password="dummy_password",
|
||||||
|
default_in_redis_ttl=2.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
f"{uuid.uuid4()} hello this is ishaan",
|
||||||
|
f"{uuid.uuid4()} hello this is ishaan again",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the embedding method
|
||||||
|
embedding_val_1 = await litellm.aembedding(
|
||||||
|
model="azure/azure-embedding-model",
|
||||||
|
input=inputs,
|
||||||
|
encoding_format="base64",
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(3) # Wait for TTL to expire
|
||||||
|
|
||||||
|
# Check if set was called on the pipeline
|
||||||
|
mock_set.assert_called()
|
||||||
|
|
||||||
|
# Check if the TTL was set correctly
|
||||||
|
for call in mock_set.call_args_list:
|
||||||
|
args, kwargs = call
|
||||||
|
print(f"redis pipeline set args: {args}")
|
||||||
|
print(f"redis pipeline set kwargs: {kwargs}")
|
||||||
|
assert kwargs.get("ex") == datetime.timedelta(
|
||||||
|
seconds=2.5
|
||||||
|
) # Check if TTL is set to 2.5 seconds
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redis_cache_basic():
|
async def test_redis_cache_basic():
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue