forked from phoenix/litellm-mirror
(refactor caching) use common _retrieve_from_cache helper (#6212)

* use folder for caching
* fix importing caching
* fix clickhouse pyright
* fix linting
* fix correctly pass kwargs and args
* fix test case for embedding
* fix linting
* fix embedding caching logic
* fix refactor handle utils.py
* refactor async set stream cache
* fix linting
* refactor - use _retrieve_from_cache
* refactor use _convert_cached_result_to_model_response
* fix linting errors

This commit is contained in:
parent 284deafd0d
commit 4ebe6813c0

5 changed files with 217 additions and 135 deletions
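
At a glance, the refactor splits the async cache read into two helpers: `_retrieve_from_cache` (fetch the raw cached value from whichever backend is configured) and `_convert_cached_result_to_model_response` (rebuild the typed response object from it). A minimal, self-contained sketch of that shape — the stub names and toy cache below are illustrative stand-ins, not litellm's API:

import asyncio
from typing import Any, Dict, Optional

# Toy backend standing in for litellm.cache (illustrative only).
FAKE_CACHE: Dict[str, Any] = {"key-1": {"object": "chat.completion", "id": "chatcmpl-1"}}


async def retrieve_from_cache(cache_key: str) -> Optional[dict]:
    """Sole owner of 'how do we read the cache backend'."""
    return FAKE_CACHE.get(cache_key)


def convert_cached_result_to_model_response(cached: dict, call_type: str) -> Any:
    """Sole owner of 'what Python object does the caller expect back'."""
    if call_type == "acompletion":
        return {"type": "ModelResponse", **cached}
    return cached


async def cached_call(cache_key: str, call_type: str) -> Optional[Any]:
    cached = await retrieve_from_cache(cache_key)
    if cached is None:
        return None  # miss: the caller would go on to hit the real LLM API
    return convert_cached_result_to_model_response(cached, call_type)


print(asyncio.run(cached_call("key-1", "acompletion")))

The actual helpers, added further down in this diff, take the full call kwargs and handle streaming, logging, and every supported call type.
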
@@ -25,8 +25,9 @@ from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.types.caching import *
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
-from litellm.types.utils import CachingSupportedCallTypes, all_litellm_params
+from litellm.types.utils import all_litellm_params
 
 
 def print_verbose(print_statement):
@@ -2125,9 +2126,7 @@ class DualCache(BaseCache):
 class Cache:
     def __init__(
         self,
-        type: Optional[
-            Literal["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]
-        ] = "local",
+        type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
         mode: Optional[
             CacheMode
         ] = CacheMode.default_on,  # when default_on cache is always on, when default_off cache is opt in
@@ -2221,7 +2220,7 @@ class Cache:
         Returns:
             None. Cache is set as a litellm param
         """
-        if type == "redis":
+        if type == LiteLLMCacheType.REDIS:
             self.cache: BaseCache = RedisCache(
                 host=host,
                 port=port,
@@ -2230,7 +2229,7 @@ class Cache:
                 startup_nodes=redis_startup_nodes,
                 **kwargs,
             )
-        elif type == "redis-semantic":
+        elif type == LiteLLMCacheType.REDIS_SEMANTIC:
            self.cache = RedisSemanticCache(
                 host=host,
                 port=port,
@@ -2240,7 +2239,7 @@ class Cache:
                 embedding_model=redis_semantic_cache_embedding_model,
                 **kwargs,
             )
-        elif type == "qdrant-semantic":
+        elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
             self.cache = QdrantSemanticCache(
                 qdrant_api_base=qdrant_api_base,
                 qdrant_api_key=qdrant_api_key,
@@ -2249,9 +2248,9 @@ class Cache:
                 quantization_config=qdrant_quantization_config,
                 embedding_model=qdrant_semantic_cache_embedding_model,
             )
-        elif type == "local":
+        elif type == LiteLLMCacheType.LOCAL:
             self.cache = InMemoryCache()
-        elif type == "s3":
+        elif type == LiteLLMCacheType.S3:
             self.cache = S3Cache(
                 s3_bucket_name=s3_bucket_name,
                 s3_region_name=s3_region_name,
@@ -2266,7 +2265,7 @@ class Cache:
                 s3_path=s3_path,
                 **kwargs,
             )
-        elif type == "disk":
+        elif type == LiteLLMCacheType.DISK:
             self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
         if "cache" not in litellm.input_callback:
             litellm.input_callback.append("cache")
@@ -2281,11 +2280,12 @@ class Cache:
         self.ttl = ttl
         self.mode: CacheMode = mode or CacheMode.default_on
 
-        if self.type == "local" and default_in_memory_ttl is not None:
+        if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
             self.ttl = default_in_memory_ttl
 
         if (
-            self.type == "redis" or self.type == "redis-semantic"
+            self.type == LiteLLMCacheType.REDIS
+            or self.type == LiteLLMCacheType.REDIS_SEMANTIC
         ) and default_in_redis_ttl is not None:
             self.ttl = default_in_redis_ttl
 
@@ -2694,6 +2694,17 @@ class Cache:
         if hasattr(self.cache, "disconnect"):
             await self.cache.disconnect()
 
+    def _supports_async(self) -> bool:
+        """
+        Internal method to check if the cache type supports async get/set operations
+
+        Only S3 Cache Does NOT support async operations
+
+        """
+        if self.type and self.type == LiteLLMCacheType.S3:
+            return False
+        return True
+
 
 class DiskCache(BaseCache):
     def __init__(self, disk_cache_dir: Optional[str] = None):
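
`_supports_async` exists so callers can pick between the awaited lookup and the synchronous one; the new `_retrieve_from_cache` helper later in this diff uses it exactly that way to special-case S3. A standalone sketch of that gate, with a toy cache class standing in for litellm's `Cache` (names here are illustrative):

import asyncio
from enum import Enum
from typing import Any, Optional


class CacheType(str, Enum):  # stand-in for LiteLLMCacheType
    LOCAL = "local"
    S3 = "s3"


class ToyCache:
    def __init__(self, type: CacheType) -> None:
        self.type = type
        self._store: dict = {}

    def _supports_async(self) -> bool:
        # Mirrors the rule in the diff: only the S3 backend is sync-only.
        return self.type != CacheType.S3

    def get_cache(self, key: str) -> Optional[Any]:
        return self._store.get(key)

    async def async_get_cache(self, key: str) -> Optional[Any]:
        return self._store.get(key)


async def lookup(cache: ToyCache, key: str) -> Optional[Any]:
    if cache._supports_async():
        return await cache.async_get_cache(key)
    return cache.get_cache(key)  # sync fallback, e.g. a boto3-backed S3 cache


print(asyncio.run(lookup(ToyCache(CacheType.S3), "missing")))  # -> None
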
@@ -2774,7 +2785,7 @@ class DiskCache(BaseCache):
 
 
 def enable_cache(
-    type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
+    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
     host: Optional[str] = None,
     port: Optional[str] = None,
     password: Optional[str] = None,
@@ -2832,7 +2843,7 @@ def enable_cache(
 
 
 def update_cache(
-    type: Optional[Literal["local", "redis", "s3", "disk"]] = "local",
+    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
     host: Optional[str] = None,
     port: Optional[str] = None,
     password: Optional[str] = None,

@@ -99,11 +99,7 @@ class LLMCachingHandler:
         Raises:
             None
         """
-        from litellm.utils import (
-            CustomStreamWrapper,
-            convert_to_model_response_object,
-            convert_to_streaming_response_async,
-        )
+        from litellm.utils import CustomStreamWrapper
 
         args = args or ()
 
@@ -124,43 +120,13 @@ class LLMCachingHandler:
                 in litellm.cache.supported_call_types
             ):
                 print_verbose("Checking Cache")
-                if call_type == CallTypes.aembedding.value and isinstance(
-                    kwargs["input"], list
-                ):
-                    tasks = []
-                    for idx, i in enumerate(kwargs["input"]):
-                        preset_cache_key = litellm.cache.get_cache_key(
-                            *args, **{**kwargs, "input": i}
-                        )
-                        tasks.append(
-                            litellm.cache.async_get_cache(cache_key=preset_cache_key)
-                        )
-                    cached_result = await asyncio.gather(*tasks)
-                    ## check if cached result is None ##
-                    if cached_result is not None and isinstance(cached_result, list):
-                        # set cached_result to None if all elements are None
-                        if all(result is None for result in cached_result):
-                            cached_result = None
-                elif isinstance(litellm.cache.cache, RedisSemanticCache) or isinstance(
-                    litellm.cache.cache, RedisCache
-                ):
-                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                    kwargs["preset_cache_key"] = (
-                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                    )
-                    cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
-                elif isinstance(litellm.cache.cache, QdrantSemanticCache):
-                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                    kwargs["preset_cache_key"] = (
-                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                    )
-                    cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
-                else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
-                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                    kwargs["preset_cache_key"] = (
-                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                    )
-                    cached_result = litellm.cache.get_cache(*args, **kwargs)
+
+                cached_result = await self._retrieve_from_cache(
+                    call_type=call_type,
+                    kwargs=kwargs,
+                    args=args,
+                )
+
                 if cached_result is not None and not isinstance(cached_result, list):
                     print_verbose("Cache Hit!")
                     cache_hit = True
@@ -202,69 +168,16 @@ class LLMCachingHandler:
                        stream=kwargs.get("stream", False),
                    )
                    call_type = original_function.__name__
-                   if call_type == CallTypes.acompletion.value and isinstance(
-                       cached_result, dict
-                   ):
-                       if kwargs.get("stream", False) is True:
-                           cached_result = convert_to_streaming_response_async(
-                               response_object=cached_result,
-                           )
-                           cached_result = CustomStreamWrapper(
-                               completion_stream=cached_result,
-                               model=model,
-                               custom_llm_provider="cached_response",
-                               logging_obj=logging_obj,
-                           )
-                       else:
-                           cached_result = convert_to_model_response_object(
-                               response_object=cached_result,
-                               model_response_object=ModelResponse(),
-                           )
-                   if call_type == CallTypes.atext_completion.value and isinstance(
-                       cached_result, dict
-                   ):
-                       if kwargs.get("stream", False) is True:
-                           cached_result = convert_to_streaming_response_async(
-                               response_object=cached_result,
-                           )
-                           cached_result = CustomStreamWrapper(
-                               completion_stream=cached_result,
-                               model=model,
-                               custom_llm_provider="cached_response",
-                               logging_obj=logging_obj,
-                           )
-                       else:
-                           cached_result = TextCompletionResponse(**cached_result)
-                   elif call_type == CallTypes.aembedding.value and isinstance(
-                       cached_result, dict
-                   ):
-                       cached_result = convert_to_model_response_object(
-                           response_object=cached_result,
-                           model_response_object=EmbeddingResponse(),
-                           response_type="embedding",
-                       )
-                   elif call_type == CallTypes.arerank.value and isinstance(
-                       cached_result, dict
-                   ):
-                       cached_result = convert_to_model_response_object(
-                           response_object=cached_result,
-                           model_response_object=None,
-                           response_type="rerank",
-                       )
-                   elif call_type == CallTypes.atranscription.value and isinstance(
-                       cached_result, dict
-                   ):
-                       hidden_params = {
-                           "model": "whisper-1",
-                           "custom_llm_provider": custom_llm_provider,
-                           "cache_hit": True,
-                       }
-                       cached_result = convert_to_model_response_object(
-                           response_object=cached_result,
-                           model_response_object=TranscriptionResponse(),
-                           response_type="audio_transcription",
-                           hidden_params=hidden_params,
-                       )
+                   cached_result = self._convert_cached_result_to_model_response(
+                       cached_result=cached_result,
+                       call_type=call_type,
+                       kwargs=kwargs,
+                       logging_obj=logging_obj,
+                       model=model,
+                       custom_llm_provider=kwargs.get("custom_llm_provider", None),
+                       args=args,
+                   )
+
                    if kwargs.get("stream", False) is False:
                        # LOG SUCCESS
                        asyncio.create_task(
@@ -387,6 +300,151 @@ class LLMCachingHandler:
             final_embedding_cached_response=final_embedding_cached_response,
         )
 
+    async def _retrieve_from_cache(
+        self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
+    ) -> Optional[Any]:
+        """
+        Internal method to
+        - get cache key
+        - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
+        - async get cache value
+        - return the cached value
+
+        Args:
+            call_type: str:
+            kwargs: Dict[str, Any]:
+            args: Optional[Tuple[Any, ...]] = None:
+
+        Returns:
+            Optional[Any]:
+        Raises:
+            None
+        """
+        if litellm.cache is None:
+            return None
+
+        cached_result: Optional[Any] = None
+        if call_type == CallTypes.aembedding.value and isinstance(
+            kwargs["input"], list
+        ):
+            tasks = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = litellm.cache.get_cache_key(
+                    *args, **{**kwargs, "input": i}
+                )
+                tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
+            cached_result = await asyncio.gather(*tasks)
+            ## check if cached result is None ##
+            if cached_result is not None and isinstance(cached_result, list):
+                # set cached_result to None if all elements are None
+                if all(result is None for result in cached_result):
+                    cached_result = None
+        else:
+            preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
+            kwargs["preset_cache_key"] = (
+                preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+            )
+            if litellm.cache._supports_async() is True:
+                cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
+            else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
+                cached_result = litellm.cache.get_cache(*args, **kwargs)
+        return cached_result
+
+    def _convert_cached_result_to_model_response(
+        self,
+        cached_result: Any,
+        call_type: str,
+        kwargs: Dict[str, Any],
+        logging_obj: LiteLLMLoggingObj,
+        model: str,
+        args: Tuple[Any, ...],
+        custom_llm_provider: Optional[str] = None,
+    ) -> Optional[Any]:
+        """
+        Internal method to process the cached result
+
+        Checks the call type and converts the cached result to the appropriate model response object
+        example if call type is text_completion -> returns TextCompletionResponse object
+
+        Args:
+            cached_result: Any:
+            call_type: str:
+            kwargs: Dict[str, Any]:
+            logging_obj: LiteLLMLoggingObj:
+            model: str:
+            custom_llm_provider: Optional[str] = None:
+            args: Optional[Tuple[Any, ...]] = None:
+
+        Returns:
+            Optional[Any]:
+        """
+        from litellm.utils import (
+            CustomStreamWrapper,
+            convert_to_model_response_object,
+            convert_to_streaming_response_async,
+        )
+
+        if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
+            if kwargs.get("stream", False) is True:
+                cached_result = convert_to_streaming_response_async(
+                    response_object=cached_result,
+                )
+                cached_result = CustomStreamWrapper(
+                    completion_stream=cached_result,
+                    model=model,
+                    custom_llm_provider="cached_response",
+                    logging_obj=logging_obj,
+                )
+            else:
+                cached_result = convert_to_model_response_object(
+                    response_object=cached_result,
+                    model_response_object=ModelResponse(),
+                )
+        if call_type == CallTypes.atext_completion.value and isinstance(
+            cached_result, dict
+        ):
+            if kwargs.get("stream", False) is True:
+                cached_result = convert_to_streaming_response_async(
+                    response_object=cached_result,
+                )
+                cached_result = CustomStreamWrapper(
+                    completion_stream=cached_result,
+                    model=model,
+                    custom_llm_provider="cached_response",
+                    logging_obj=logging_obj,
+                )
+            else:
+                cached_result = TextCompletionResponse(**cached_result)
+        elif call_type == CallTypes.aembedding.value and isinstance(
+            cached_result, dict
+        ):
+            cached_result = convert_to_model_response_object(
+                response_object=cached_result,
+                model_response_object=EmbeddingResponse(),
+                response_type="embedding",
+            )
+        elif call_type == CallTypes.arerank.value and isinstance(cached_result, dict):
+            cached_result = convert_to_model_response_object(
+                response_object=cached_result,
+                model_response_object=None,
+                response_type="rerank",
+            )
+        elif call_type == CallTypes.atranscription.value and isinstance(
+            cached_result, dict
+        ):
+            hidden_params = {
+                "model": "whisper-1",
+                "custom_llm_provider": custom_llm_provider,
+                "cache_hit": True,
+            }
+            cached_result = convert_to_model_response_object(
+                response_object=cached_result,
+                model_response_object=TranscriptionResponse(),
+                response_type="audio_transcription",
+                hidden_params=hidden_params,
+            )
+        return cached_result
+
     async def _async_set_cache(
         self,
         result: Any,

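
For the `aembedding` call type, `_retrieve_from_cache` fans out one lookup per input item and only treats the batch as a miss when every item misses; partial hits stay in the list as positional `None`s so the caller can fill in just the missing items. A self-contained sketch of that pattern, with a plain dict standing in for the cache backend (illustrative only, not litellm's API):

import asyncio
from typing import Any, Dict, List, Optional

FAKE_CACHE: Dict[str, Any] = {"embed::hello": [0.1, 0.2]}  # pretend backend


async def async_get_cache(cache_key: str) -> Optional[Any]:
    return FAKE_CACHE.get(cache_key)


def get_cache_key(single_input: str) -> str:
    # litellm derives the key from the full call kwargs; a prefix is enough here.
    return f"embed::{single_input}"


async def retrieve_embedding_batch(inputs: List[str]) -> Optional[List[Optional[Any]]]:
    # One lookup per input item, gathered concurrently.
    tasks = [async_get_cache(get_cache_key(i)) for i in inputs]
    cached = await asyncio.gather(*tasks)
    # Only a full miss collapses to None; partial hits are kept positionally.
    if all(item is None for item in cached):
        return None
    return list(cached)


print(asyncio.run(retrieve_embedding_batch(["hello", "unseen"])))  # [[0.1, 0.2], None]
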
@@ -4,7 +4,7 @@ from typing import Callable, List, Literal, Optional, Tuple, Union
 import httpx
 
 import litellm
-from litellm.caching.caching import Cache
+from litellm.caching.caching import Cache, LiteLLMCacheType
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.OpenAI.openai import AllMessageValues
@@ -22,7 +22,9 @@ from .transformation import (
     transform_openai_messages_to_gemini_context_caching,
 )
 
-local_cache_obj = Cache(type="local")  # only used for calling 'get_cache_key' function
+local_cache_obj = Cache(
+    type=LiteLLMCacheType.LOCAL
+)  # only used for calling 'get_cache_key' function
 
 
 class ContextCachingEndpoints(VertexBase):

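
`local_cache_obj` exists only so this module can call `get_cache_key`; deriving a key is pure computation over the call arguments, so an in-memory cache with no backend connection is enough. A toy illustration of the idea — the hashing scheme below is invented for the example and is not litellm's actual key format:

import hashlib
import json
from typing import Any


def toy_get_cache_key(**call_kwargs: Any) -> str:
    # Deterministic digest over the sorted call kwargs; purely local, no I/O.
    canonical = json.dumps(call_kwargs, sort_keys=True, default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


key_a = toy_get_cache_key(model="gemini-1.5-pro", messages=[{"role": "user", "content": "hi"}])
key_b = toy_get_cache_key(messages=[{"role": "user", "content": "hi"}], model="gemini-1.5-pro")
assert key_a == key_b  # argument order does not change the key
print(key_a[:16])
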
25  litellm/types/caching.py  Normal file
@@ -0,0 +1,25 @@
+from enum import Enum
+from typing import Literal
+
+
+class LiteLLMCacheType(str, Enum):
+    LOCAL = "local"
+    REDIS = "redis"
+    REDIS_SEMANTIC = "redis-semantic"
+    S3 = "s3"
+    DISK = "disk"
+    QDRANT_SEMANTIC = "qdrant-semantic"
+
+
+CachingSupportedCallTypes = Literal[
+    "completion",
+    "acompletion",
+    "embedding",
+    "aembedding",
+    "atranscription",
+    "transcription",
+    "atext_completion",
+    "text_completion",
+    "arerank",
+    "rerank",
+]

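
Since `LiteLLMCacheType` subclasses `str`, its members compare equal to the plain strings used before this refactor, so string-typed configs and older call sites that pass "redis" and friends should continue to work. A quick standalone check (the enum body is copied from the new file above):

from enum import Enum


class LiteLLMCacheType(str, Enum):
    LOCAL = "local"
    REDIS = "redis"
    REDIS_SEMANTIC = "redis-semantic"
    S3 = "s3"
    DISK = "disk"
    QDRANT_SEMANTIC = "qdrant-semantic"


# str-subclass enums compare equal to their raw values...
assert LiteLLMCacheType.REDIS == "redis"
# ...and the raw value round-trips back to the member,
# which is how string-typed config can be normalized to the enum.
assert LiteLLMCacheType("redis-semantic") is LiteLLMCacheType.REDIS_SEMANTIC
print("string compatibility holds")
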
@@ -1418,17 +1418,3 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
     # GCS dynamic params
     gcs_bucket_name: Optional[str]
     gcs_path_service_account: Optional[str]
-
-
-CachingSupportedCallTypes = Literal[
-    "completion",
-    "acompletion",
-    "embedding",
-    "aembedding",
-    "atranscription",
-    "transcription",
-    "atext_completion",
-    "text_completion",
-    "arerank",
-    "rerank",
-]