(bug fix) TTL not being set for embedding caching requests (#6095)

* fix ttl for cache pipeline settings

* add test for caching

* add test for setting ttls on redis caching
This commit is contained in:
Ishaan Jaff 2024-10-07 15:53:18 +05:30 committed by GitHub
parent 285b589095
commit 2c8bba293f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 99 additions and 14 deletions

View file

@ -18,7 +18,7 @@ import time
import traceback
from datetime import timedelta
from enum import Enum
from typing import Any, List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
@ -455,7 +455,9 @@ class RedisCache(BaseCache):
value,
)
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
async def async_set_cache_pipeline(
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
):
"""
Use Redis Pipelines for bulk write operations
"""
@ -467,6 +469,8 @@ class RedisCache(BaseCache):
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()
ttl = ttl or kwargs.get("ttl", None)
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
@ -482,11 +486,10 @@ class RedisCache(BaseCache):
)
json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided.
_td: Optional[timedelta] = None
if ttl is not None:
pipe.setex(cache_key, ttl, json_cache_value)
else:
pipe.set(cache_key, json_cache_value)
_td = timedelta(seconds=ttl)
pipe.set(cache_key, json_cache_value, ex=_td)
# Execute the pipeline and return the results.
results = await pipe.execute()
@ -2188,14 +2191,38 @@ class Cache:
Args:
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
# Redis Cache Args
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
ttl (float, optional): The ttl for the Redis cache
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
# Qdrant Cache Args
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
# Disk Cache Args
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
# S3 Cache Args
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
# Common Cache Args
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache
@ -2207,18 +2234,18 @@ class Cache:
"""
if type == "redis":
self.cache: BaseCache = RedisCache(
host,
port,
password,
redis_flush_size,
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
elif type == "redis-semantic":
self.cache = RedisSemanticCache(
host,
port,
password,
host=host,
port=port,
password=password,
similarity_threshold=similarity_threshold,
use_async=redis_semantic_cache_use_async,
embedding_model=redis_semantic_cache_embedding_model,
@ -2598,6 +2625,11 @@ class Cache:
try:
if self.should_use_cache(*args, **kwargs) is not True:
return
# set default ttl if not set
if self.ttl is not None:
kwargs["ttl"] = self.ttl
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
@ -2613,7 +2645,7 @@ class Cache:
self.cache, "async_set_cache_pipeline", None
)
if async_set_cache_pipeline:
await async_set_cache_pipeline(cache_list=cache_list)
await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
else:
tasks = []
for val in cache_list: