From 4ebe6813c0657156beaa0f0786258b5b9edbe72d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 14 Oct 2024 19:12:41 +0530 Subject: [PATCH] (refactor caching) use common `_retrieve_from_cache` helper (#6212) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * refactor async set stream cache * fix linting * refactor - use _retrieve_from_cache * refactor use _convert_cached_result_to_model_response * fix linting errors --- litellm/caching/caching.py | 39 ++- litellm/caching/caching_handler.py | 268 +++++++++++------- .../vertex_ai_context_caching.py | 6 +- litellm/types/caching.py | 25 ++ litellm/types/utils.py | 14 - 5 files changed, 217 insertions(+), 135 deletions(-) create mode 100644 litellm/types/caching.py diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py index c16993625..088e2d03f 100644 --- a/litellm/caching/caching.py +++ b/litellm/caching/caching.py @@ -25,8 +25,9 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.types.caching import * from litellm.types.services import ServiceLoggerPayload, ServiceTypes -from litellm.types.utils import CachingSupportedCallTypes, all_litellm_params +from litellm.types.utils import all_litellm_params def print_verbose(print_statement): @@ -2125,9 +2126,7 @@ class DualCache(BaseCache): class Cache: def __init__( self, - type: Optional[ - Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"] - ] = "local", + type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL, mode: Optional[ CacheMode ] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in @@ -2221,7 +2220,7 @@ class Cache: Returns: None. 
Cache is set as a litellm param """ - if type == "redis": + if type == LiteLLMCacheType.REDIS: self.cache: BaseCache = RedisCache( host=host, port=port, @@ -2230,7 +2229,7 @@ class Cache: startup_nodes=redis_startup_nodes, **kwargs, ) - elif type == "redis-semantic": + elif type == LiteLLMCacheType.REDIS_SEMANTIC: self.cache = RedisSemanticCache( host=host, port=port, @@ -2240,7 +2239,7 @@ class Cache: embedding_model=redis_semantic_cache_embedding_model, **kwargs, ) - elif type == "qdrant-semantic": + elif type == LiteLLMCacheType.QDRANT_SEMANTIC: self.cache = QdrantSemanticCache( qdrant_api_base=qdrant_api_base, qdrant_api_key=qdrant_api_key, @@ -2249,9 +2248,9 @@ class Cache: quantization_config=qdrant_quantization_config, embedding_model=qdrant_semantic_cache_embedding_model, ) - elif type == "local": + elif type == LiteLLMCacheType.LOCAL: self.cache = InMemoryCache() - elif type == "s3": + elif type == LiteLLMCacheType.S3: self.cache = S3Cache( s3_bucket_name=s3_bucket_name, s3_region_name=s3_region_name, @@ -2266,7 +2265,7 @@ class Cache: s3_path=s3_path, **kwargs, ) - elif type == "disk": + elif type == LiteLLMCacheType.DISK: self.cache = DiskCache(disk_cache_dir=disk_cache_dir) if "cache" not in litellm.input_callback: litellm.input_callback.append("cache") @@ -2281,11 +2280,12 @@ class Cache: self.ttl = ttl self.mode: CacheMode = mode or CacheMode.default_on - if self.type == "local" and default_in_memory_ttl is not None: + if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None: self.ttl = default_in_memory_ttl if ( - self.type == "redis" or self.type == "redis-semantic" + self.type == LiteLLMCacheType.REDIS + or self.type == LiteLLMCacheType.REDIS_SEMANTIC ) and default_in_redis_ttl is not None: self.ttl = default_in_redis_ttl @@ -2694,6 +2694,17 @@ class Cache: if hasattr(self.cache, "disconnect"): await self.cache.disconnect() + def _supports_async(self) -> bool: + """ + Internal method to check if the cache type supports async get/set operations + + Only S3 Cache Does NOT support async operations + + """ + if self.type and self.type == LiteLLMCacheType.S3: + return False + return True + class DiskCache(BaseCache): def __init__(self, disk_cache_dir: Optional[str] = None): @@ -2774,7 +2785,7 @@ class DiskCache(BaseCache): def enable_cache( - type: Optional[Literal["local", "redis", "s3", "disk"]] = "local", + type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL, host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, @@ -2832,7 +2843,7 @@ def enable_cache( def update_cache( - type: Optional[Literal["local", "redis", "s3", "disk"]] = "local", + type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL, host: Optional[str] = None, port: Optional[str] = None, password: Optional[str] = None, diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index 11f055ffe..264fb405b 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -99,11 +99,7 @@ class LLMCachingHandler: Raises: None """ - from litellm.utils import ( - CustomStreamWrapper, - convert_to_model_response_object, - convert_to_streaming_response_async, - ) + from litellm.utils import CustomStreamWrapper args = args or () @@ -124,43 +120,13 @@ class LLMCachingHandler: in litellm.cache.supported_call_types ): print_verbose("Checking Cache") - if call_type == CallTypes.aembedding.value and isinstance( - kwargs["input"], list - ): - tasks = [] - for idx, i in enumerate(kwargs["input"]): - preset_cache_key = 
litellm.cache.get_cache_key( - *args, **{**kwargs, "input": i} - ) - tasks.append( - litellm.cache.async_get_cache(cache_key=preset_cache_key) - ) - cached_result = await asyncio.gather(*tasks) - ## check if cached result is None ## - if cached_result is not None and isinstance(cached_result, list): - # set cached_result to None if all elements are None - if all(result is None for result in cached_result): - cached_result = None - elif isinstance(litellm.cache.cache, RedisSemanticCache) or isinstance( - litellm.cache.cache, RedisCache - ): - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = await litellm.cache.async_get_cache(*args, **kwargs) - elif isinstance(litellm.cache.cache, QdrantSemanticCache): - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = await litellm.cache.async_get_cache(*args, **kwargs) - else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync] - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = litellm.cache.get_cache(*args, **kwargs) + + cached_result = await self._retrieve_from_cache( + call_type=call_type, + kwargs=kwargs, + args=args, + ) + if cached_result is not None and not isinstance(cached_result, list): print_verbose("Cache Hit!") cache_hit = True @@ -202,69 +168,16 @@ class LLMCachingHandler: stream=kwargs.get("stream", False), ) call_type = original_function.__name__ - if call_type == CallTypes.acompletion.value and isinstance( - cached_result, dict - ): - if kwargs.get("stream", False) is True: - cached_result = convert_to_streaming_response_async( - response_object=cached_result, - ) - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - else: - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - ) - if call_type == CallTypes.atext_completion.value and isinstance( - cached_result, dict - ): - if kwargs.get("stream", False) is True: - cached_result = convert_to_streaming_response_async( - response_object=cached_result, - ) - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - else: - cached_result = TextCompletionResponse(**cached_result) - elif call_type == CallTypes.aembedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=EmbeddingResponse(), - response_type="embedding", - ) - elif call_type == CallTypes.arerank.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=None, - response_type="rerank", - ) - elif call_type == CallTypes.atranscription.value and isinstance( - cached_result, dict - ): - hidden_params = { - "model": "whisper-1", - "custom_llm_provider": custom_llm_provider, - "cache_hit": True, - } - cached_result = convert_to_model_response_object( - 
response_object=cached_result, - model_response_object=TranscriptionResponse(), - response_type="audio_transcription", - hidden_params=hidden_params, - ) + + cached_result = self._convert_cached_result_to_model_response( + cached_result=cached_result, + call_type=call_type, + kwargs=kwargs, + logging_obj=logging_obj, + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + args=args, + ) if kwargs.get("stream", False) is False: # LOG SUCCESS asyncio.create_task( @@ -387,6 +300,151 @@ class LLMCachingHandler: final_embedding_cached_response=final_embedding_cached_response, ) + async def _retrieve_from_cache( + self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...] + ) -> Optional[Any]: + """ + Internal method to + - get cache key + - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3 + - async get cache value + - return the cached value + + Args: + call_type: str: + kwargs: Dict[str, Any]: + args: Optional[Tuple[Any, ...]] = None: + + Returns: + Optional[Any]: + Raises: + None + """ + if litellm.cache is None: + return None + + cached_result: Optional[Any] = None + if call_type == CallTypes.aembedding.value and isinstance( + kwargs["input"], list + ): + tasks = [] + for idx, i in enumerate(kwargs["input"]): + preset_cache_key = litellm.cache.get_cache_key( + *args, **{**kwargs, "input": i} + ) + tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key)) + cached_result = await asyncio.gather(*tasks) + ## check if cached result is None ## + if cached_result is not None and isinstance(cached_result, list): + # set cached_result to None if all elements are None + if all(result is None for result in cached_result): + cached_result = None + else: + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + if litellm.cache._supports_async() is True: + cached_result = await litellm.cache.async_get_cache(*args, **kwargs) + else: # for s3 caching. 
[NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync] + cached_result = litellm.cache.get_cache(*args, **kwargs) + return cached_result + + def _convert_cached_result_to_model_response( + self, + cached_result: Any, + call_type: str, + kwargs: Dict[str, Any], + logging_obj: LiteLLMLoggingObj, + model: str, + args: Tuple[Any, ...], + custom_llm_provider: Optional[str] = None, + ) -> Optional[Any]: + """ + Internal method to process the cached result + + Checks the call type and converts the cached result to the appropriate model response object + example if call type is text_completion -> returns TextCompletionResponse object + + Args: + cached_result: Any: + call_type: str: + kwargs: Dict[str, Any]: + logging_obj: LiteLLMLoggingObj: + model: str: + custom_llm_provider: Optional[str] = None: + args: Optional[Tuple[Any, ...]] = None: + + Returns: + Optional[Any]: + """ + from litellm.utils import ( + CustomStreamWrapper, + convert_to_model_response_object, + convert_to_streaming_response_async, + ) + + if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict): + if kwargs.get("stream", False) is True: + cached_result = convert_to_streaming_response_async( + response_object=cached_result, + ) + cached_result = CustomStreamWrapper( + completion_stream=cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + else: + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + ) + if call_type == CallTypes.atext_completion.value and isinstance( + cached_result, dict + ): + if kwargs.get("stream", False) is True: + cached_result = convert_to_streaming_response_async( + response_object=cached_result, + ) + cached_result = CustomStreamWrapper( + completion_stream=cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + else: + cached_result = TextCompletionResponse(**cached_result) + elif call_type == CallTypes.aembedding.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=EmbeddingResponse(), + response_type="embedding", + ) + elif call_type == CallTypes.arerank.value and isinstance(cached_result, dict): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=None, + response_type="rerank", + ) + elif call_type == CallTypes.atranscription.value and isinstance( + cached_result, dict + ): + hidden_params = { + "model": "whisper-1", + "custom_llm_provider": custom_llm_provider, + "cache_hit": True, + } + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=TranscriptionResponse(), + response_type="audio_transcription", + hidden_params=hidden_params, + ) + return cached_result + async def _async_set_cache( self, result: Any, diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py index 2dafce6a9..e60a17052 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/vertex_ai_context_caching.py @@ -4,7 +4,7 @@ from typing import Callable, List, Literal, Optional, Tuple, Union import httpx import litellm -from litellm.caching.caching import Cache +from 
litellm.caching.caching import Cache, LiteLLMCacheType from litellm.litellm_core_utils.litellm_logging import Logging from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.OpenAI.openai import AllMessageValues @@ -22,7 +22,9 @@ from .transformation import ( transform_openai_messages_to_gemini_context_caching, ) -local_cache_obj = Cache(type="local") # only used for calling 'get_cache_key' function +local_cache_obj = Cache( + type=LiteLLMCacheType.LOCAL +) # only used for calling 'get_cache_key' function class ContextCachingEndpoints(VertexBase): diff --git a/litellm/types/caching.py b/litellm/types/caching.py new file mode 100644 index 000000000..7fca4c041 --- /dev/null +++ b/litellm/types/caching.py @@ -0,0 +1,25 @@ +from enum import Enum +from typing import Literal + + +class LiteLLMCacheType(str, Enum): + LOCAL = "local" + REDIS = "redis" + REDIS_SEMANTIC = "redis-semantic" + S3 = "s3" + DISK = "disk" + QDRANT_SEMANTIC = "qdrant-semantic" + + +CachingSupportedCallTypes = Literal[ + "completion", + "acompletion", + "embedding", + "aembedding", + "atranscription", + "transcription", + "atext_completion", + "text_completion", + "arerank", + "rerank", +] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 2a36dd84d..c3118b453 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1418,17 +1418,3 @@ class StandardCallbackDynamicParams(TypedDict, total=False): # GCS dynamic params gcs_bucket_name: Optional[str] gcs_path_service_account: Optional[str] - - -CachingSupportedCallTypes = Literal[ - "completion", - "acompletion", - "embedding", - "aembedding", - "atranscription", - "transcription", - "atext_completion", - "text_completion", - "arerank", - "rerank", -]
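
A minimal usage sketch of the `LiteLLMCacheType` enum introduced above. Because the enum mixes in `str`, its members compare equal to the old literal strings, so call sites that still pass `type="local"` or `type="redis"` continue to match the `==` checks in `Cache.__init__`. The snippet sticks to the local backend so it runs without external services; any other constructor arguments are omitted:

    import litellm
    from litellm.caching.caching import Cache, LiteLLMCacheType

    # Enum-based call style used throughout this patch:
    litellm.cache = Cache(type=LiteLLMCacheType.LOCAL)

    # The old string literal still works because LiteLLMCacheType subclasses str,
    # so the `type == LiteLLMCacheType.LOCAL` comparison also holds for "local".
    litellm.cache = Cache(type="local")

    assert LiteLLMCacheType.LOCAL == "local"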
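
The backend-specific branching that `_retrieve_from_cache` replaces now reduces to one capability check: `Cache._supports_async()` returns False only for the S3 backend, and the handler picks the async or sync read path accordingly. A condensed sketch of that dispatch, with the embedding batch handling omitted and the helper name chosen purely for illustration:

    from typing import Any, Optional

    async def get_cached_value_sketch(cache, *args, **kwargs) -> Optional[Any]:
        """Simplified view of the read path in LLMCachingHandler._retrieve_from_cache."""
        preset_cache_key = cache.get_cache_key(*args, **kwargs)
        # Streaming calls need the same key later, so it is pinned into kwargs.
        kwargs["preset_cache_key"] = preset_cache_key

        if cache._supports_async():
            # Redis, Redis-semantic, Qdrant-semantic, local, and disk caches.
            return await cache.async_get_cache(*args, **kwargs)
        # S3 falls back to the sync boto3 client, which blocks the event loop.
        return cache.get_cache(*args, **kwargs)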