(refactor caching) use common _retrieve_from_cache helper (#6212)

* use folder for caching

* fix caching imports

* fix clickhouse pyright

* fix linting

* fix passing of kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix: refactor handling in utils.py

* refactor async set stream cache

* fix linting

* refactor - use _retrieve_from_cache

* refactor use _convert_cached_result_to_model_response

* fix linting errors
Ishaan Jaff 2024-10-14 19:12:41 +05:30 committed by GitHub
parent 284deafd0d
commit 4ebe6813c0
5 changed files with 217 additions and 135 deletions


@@ -25,8 +25,9 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.caching import *
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from litellm.types.utils import CachingSupportedCallTypes, all_litellm_params
from litellm.types.utils import all_litellm_params
def print_verbose(print_statement):
@@ -2125,9 +2126,7 @@ class DualCache(BaseCache):
class Cache:
def __init__(
self,
type: Optional[
Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]
] = "local",
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
mode: Optional[
CacheMode
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
@@ -2221,7 +2220,7 @@ class Cache:
Returns:
None. Cache is set as a litellm param
"""
if type == "redis":
if type == LiteLLMCacheType.REDIS:
self.cache: BaseCache = RedisCache(
host=host,
port=port,
@@ -2230,7 +2229,7 @@
startup_nodes=redis_startup_nodes,
**kwargs,
)
elif type == "redis-semantic":
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
self.cache = RedisSemanticCache(
host=host,
port=port,
@@ -2240,7 +2239,7 @@
embedding_model=redis_semantic_cache_embedding_model,
**kwargs,
)
elif type == "qdrant-semantic":
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
self.cache = QdrantSemanticCache(
qdrant_api_base=qdrant_api_base,
qdrant_api_key=qdrant_api_key,
@@ -2249,9 +2248,9 @@
quantization_config=qdrant_quantization_config,
embedding_model=qdrant_semantic_cache_embedding_model,
)
elif type == "local":
elif type == LiteLLMCacheType.LOCAL:
self.cache = InMemoryCache()
elif type == "s3":
elif type == LiteLLMCacheType.S3:
self.cache = S3Cache(
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
@@ -2266,7 +2265,7 @@
s3_path=s3_path,
**kwargs,
)
elif type == "disk":
elif type == LiteLLMCacheType.DISK:
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
@@ -2281,11 +2280,12 @@
self.ttl = ttl
self.mode: CacheMode = mode or CacheMode.default_on
if self.type == "local" and default_in_memory_ttl is not None:
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
self.ttl = default_in_memory_ttl
if (
self.type == "redis" or self.type == "redis-semantic"
self.type == LiteLLMCacheType.REDIS
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
) and default_in_redis_ttl is not None:
self.ttl = default_in_redis_ttl
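A minimal usage sketch of the new enum-based `type` parameter (illustration only, not part of this diff). Because `LiteLLMCacheType` is a `str`-based enum, the old string literals still compare equal at runtime, so existing call sites keep working while type checkers can now prefer the enum members:

```python
import litellm
from litellm.caching.caching import Cache, LiteLLMCacheType

# Preferred after this refactor: pass an enum member.
litellm.cache = Cache(type=LiteLLMCacheType.LOCAL)

# The legacy string form still works at runtime, since "local" == LiteLLMCacheType.LOCAL,
# though static checkers may flag it against the Optional[LiteLLMCacheType] annotation.
legacy_cache = Cache(type="local")  # type: ignore[arg-type]
```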
@@ -2694,6 +2694,17 @@ class Cache:
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
def _supports_async(self) -> bool:
"""
Internal method to check if the cache type supports async get/set operations
Only S3 Cache Does NOT support async operations
"""
if self.type and self.type == LiteLLMCacheType.S3:
return False
return True
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
@@ -2774,7 +2785,7 @@ class DiskCache(BaseCache):
def enable_cache(
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
@@ -2832,7 +2843,7 @@ def enable_cache(
def update_cache(
type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,


@@ -99,11 +99,7 @@ class LLMCachingHandler:
Raises:
None
"""
from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
convert_to_streaming_response_async,
)
from litellm.utils import CustomStreamWrapper
args = args or ()
@@ -124,43 +120,13 @@
in litellm.cache.supported_call_types
):
print_verbose("Checking Cache")
if call_type == CallTypes.aembedding.value and isinstance(
kwargs["input"], list
):
tasks = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
)
tasks.append(
litellm.cache.async_get_cache(cache_key=preset_cache_key)
)
cached_result = await asyncio.gather(*tasks)
## check if cached result is None ##
if cached_result is not None and isinstance(cached_result, list):
# set cached_result to None if all elements are None
if all(result is None for result in cached_result):
cached_result = None
elif isinstance(litellm.cache.cache, RedisSemanticCache) or isinstance(
litellm.cache.cache, RedisCache
):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
elif isinstance(litellm.cache.cache, QdrantSemanticCache):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
cached_result = await self._retrieve_from_cache(
call_type=call_type,
kwargs=kwargs,
args=args,
)
if cached_result is not None and not isinstance(cached_result, list):
print_verbose("Cache Hit!")
cache_hit = True
@@ -202,69 +168,16 @@ class LLMCachingHandler:
stream=kwargs.get("stream", False),
)
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) is True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
if call_type == CallTypes.atext_completion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) is True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = TextCompletionResponse(**cached_result)
elif call_type == CallTypes.aembedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif call_type == CallTypes.arerank.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=None,
response_type="rerank",
)
elif call_type == CallTypes.atranscription.value and isinstance(
cached_result, dict
):
hidden_params = {
"model": "whisper-1",
"custom_llm_provider": custom_llm_provider,
"cache_hit": True,
}
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=TranscriptionResponse(),
response_type="audio_transcription",
hidden_params=hidden_params,
)
cached_result = self._convert_cached_result_to_model_response(
cached_result=cached_result,
call_type=call_type,
kwargs=kwargs,
logging_obj=logging_obj,
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
args=args,
)
if kwargs.get("stream", False) is False:
# LOG SUCCESS
asyncio.create_task(
@@ -387,6 +300,151 @@ class LLMCachingHandler:
final_embedding_cached_response=final_embedding_cached_response,
)
async def _retrieve_from_cache(
self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
) -> Optional[Any]:
"""
Internal method to
- get cache key
- check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
- async get cache value
- return the cached value
Args:
call_type: str:
kwargs: Dict[str, Any]:
args: Optional[Tuple[Any, ...]] = None:
Returns:
Optional[Any]:
Raises:
None
"""
if litellm.cache is None:
return None
cached_result: Optional[Any] = None
if call_type == CallTypes.aembedding.value and isinstance(
kwargs["input"], list
):
tasks = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
)
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
cached_result = await asyncio.gather(*tasks)
## check if cached result is None ##
if cached_result is not None and isinstance(cached_result, list):
# set cached_result to None if all elements are None
if all(result is None for result in cached_result):
cached_result = None
else:
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
if litellm.cache._supports_async() is True:
cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
cached_result = litellm.cache.get_cache(*args, **kwargs)
return cached_result
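The `aembedding` branch above fans out one cache lookup per input item and gathers them concurrently. The standalone sketch below isolates that idiom with a hypothetical `fan_out_lookup` helper (not litellm's API): per-item keys are built by overriding `input` in the kwargs dict, and an all-miss result collapses to `None`, mirroring the handler's behavior.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional


async def fan_out_lookup(
    make_key: Callable[..., str],
    cache_get: Callable[..., Awaitable[Any]],
    kwargs: Dict[str, Any],
) -> Optional[List[Any]]:
    # One cache key per embedding input, keeping every other kwarg identical.
    keys = [make_key(**{**kwargs, "input": item}) for item in kwargs["input"]]
    # Issue all lookups concurrently instead of awaiting them one by one.
    results = await asyncio.gather(*(cache_get(cache_key=k) for k in keys))
    # Treat "every item missed" as a single miss.
    return None if all(r is None for r in results) else list(results)
```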
def _convert_cached_result_to_model_response(
self,
cached_result: Any,
call_type: str,
kwargs: Dict[str, Any],
logging_obj: LiteLLMLoggingObj,
model: str,
args: Tuple[Any, ...],
custom_llm_provider: Optional[str] = None,
) -> Optional[Any]:
"""
Internal method to process the cached result
Checks the call type and converts the cached result to the appropriate model response object
example if call type is text_completion -> returns TextCompletionResponse object
Args:
cached_result: Any:
call_type: str:
kwargs: Dict[str, Any]:
logging_obj: LiteLLMLoggingObj:
model: str:
custom_llm_provider: Optional[str] = None:
args: Optional[Tuple[Any, ...]] = None:
Returns:
Optional[Any]:
"""
from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
convert_to_streaming_response_async,
)
if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
if kwargs.get("stream", False) is True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
if call_type == CallTypes.atext_completion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) is True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = TextCompletionResponse(**cached_result)
elif call_type == CallTypes.aembedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif call_type == CallTypes.arerank.value and isinstance(cached_result, dict):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=None,
response_type="rerank",
)
elif call_type == CallTypes.atranscription.value and isinstance(
cached_result, dict
):
hidden_params = {
"model": "whisper-1",
"custom_llm_provider": custom_llm_provider,
"cache_hit": True,
}
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=TranscriptionResponse(),
response_type="audio_transcription",
hidden_params=hidden_params,
)
return cached_result
async def _async_set_cache(
self,
result: Any,


@@ -4,7 +4,7 @@ from typing import Callable, List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm.caching.caching import Cache
from litellm.caching.caching import Cache, LiteLLMCacheType
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.OpenAI.openai import AllMessageValues
@@ -22,7 +22,9 @@ from .transformation import (
transform_openai_messages_to_gemini_context_caching,
)
local_cache_obj = Cache(type="local") # only used for calling 'get_cache_key' function
local_cache_obj = Cache(
type=LiteLLMCacheType.LOCAL
) # only used for calling 'get_cache_key' function
class ContextCachingEndpoints(VertexBase):
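Here `local_cache_obj` is only used to derive deterministic cache keys, never to store values. A hedged sketch of that usage follows; the exact kwargs that feed the key (`model`, `messages`) are illustrative assumptions, not a documented contract:

```python
from litellm.caching.caching import Cache, LiteLLMCacheType

local_cache_obj = Cache(type=LiteLLMCacheType.LOCAL)

# Derive a stable key for a message list, e.g. to decide whether a Gemini
# context-cache entry for these messages was already created.
key = local_cache_obj.get_cache_key(
    model="gemini-1.5-pro",  # illustrative model name
    messages=[{"role": "user", "content": "very long shared context ..."}],
)
```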

litellm/types/caching.py (new file, 25 lines)

@@ -0,0 +1,25 @@
from enum import Enum
from typing import Literal
class LiteLLMCacheType(str, Enum):
LOCAL = "local"
REDIS = "redis"
REDIS_SEMANTIC = "redis-semantic"
S3 = "s3"
DISK = "disk"
QDRANT_SEMANTIC = "qdrant-semantic"
CachingSupportedCallTypes = Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
]
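A quick illustration of why the `str`-based enum is a drop-in replacement for the old string literals: members compare equal to their values and round-trip from them (standard Python `str`-enum behavior, shown here as a sketch rather than a litellm test).

```python
from litellm.types.caching import LiteLLMCacheType

assert LiteLLMCacheType.REDIS == "redis"
assert LiteLLMCacheType("redis-semantic") is LiteLLMCacheType.REDIS_SEMANTIC
assert LiteLLMCacheType.S3.value == "s3"
```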


@@ -1418,17 +1418,3 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
# GCS dynamic params
gcs_bucket_name: Optional[str]
gcs_path_service_account: Optional[str]
CachingSupportedCallTypes = Literal[
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
]