(refactor caching) use common _retrieve_from_cache helper (#6212)

* use folder for caching

* fix importing caching

* fix clickhouse pyright

* fix linting

* fix correctly pass kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix refactor handle utils.py

* refactor async set stream cache

* fix linting

* refactor - use _retrieve_from_cache

* refactor use _convert_cached_result_to_model_response

* fix linting errors
This commit is contained in:
Ishaan Jaff 2024-10-14 19:12:41 +05:30 committed by GitHub
parent 284deafd0d
commit 4ebe6813c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 217 additions and 135 deletions

View file

@ -25,8 +25,9 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.caching import *
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from litellm.types.utils import CachingSupportedCallTypes, all_litellm_params
from litellm.types.utils import all_litellm_params
def print_verbose(print_statement):
@ -2125,9 +2126,7 @@ class DualCache(BaseCache):
class Cache:
def __init__(
self,
type: Optional[
Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]
] = "local",
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
mode: Optional[
CacheMode
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
@ -2221,7 +2220,7 @@ class Cache:
Returns:
None. Cache is set as a litellm param
"""
if type == "redis":
if type == LiteLLMCacheType.REDIS:
self.cache: BaseCache = RedisCache(
host=host,
port=port,
@ -2230,7 +2229,7 @@ class Cache:
startup_nodes=redis_startup_nodes,
**kwargs,
)
elif type == "redis-semantic":
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
self.cache = RedisSemanticCache(
host=host,
port=port,
@ -2240,7 +2239,7 @@ class Cache:
embedding_model=redis_semantic_cache_embedding_model,
**kwargs,
)
elif type == "qdrant-semantic":
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
self.cache = QdrantSemanticCache(
qdrant_api_base=qdrant_api_base,
qdrant_api_key=qdrant_api_key,
@ -2249,9 +2248,9 @@ class Cache:
quantization_config=qdrant_quantization_config,
embedding_model=qdrant_semantic_cache_embedding_model,
)
elif type == "local":
elif type == LiteLLMCacheType.LOCAL:
self.cache = InMemoryCache()
elif type == "s3":
elif type == LiteLLMCacheType.S3:
self.cache = S3Cache(
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
@ -2266,7 +2265,7 @@ class Cache:
s3_path=s3_path,
**kwargs,
)
elif type == "disk":
elif type == LiteLLMCacheType.DISK:
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
@ -2281,11 +2280,12 @@ class Cache:
self.ttl = ttl
self.mode: CacheMode = mode or CacheMode.default_on
if self.type == "local" and default_in_memory_ttl is not None:
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
self.ttl = default_in_memory_ttl
if (
self.type == "redis" or self.type == "redis-semantic"
self.type == LiteLLMCacheType.REDIS
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
) and default_in_redis_ttl is not None:
self.ttl = default_in_redis_ttl
@ -2694,6 +2694,17 @@ class Cache:
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
def _supports_async(self) -> bool:
"""
Internal method to check if the cache type supports async get/set operations
Only S3 Cache Does NOT support async operations
"""
if self.type and self.type == LiteLLMCacheType.S3:
return False
return True
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
@ -2774,7 +2785,7 @@ class DiskCache(BaseCache):
def enable_cache(
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
@ -2832,7 +2843,7 @@ def enable_cache(
def update_cache(
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,