Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
(refactor caching) use common _retrieve_from_cache helper (#6212)

* use folder for caching
* fix importing caching
* fix clickhouse pyright
* fix linting
* fix correctly pass kwargs and args
* fix test case for embedding
* fix linting
* fix embedding caching logic
* fix refactor handle utils.py
* refactor async set stream cache
* fix linting
* refactor - use _retrieve_from_cache
* refactor use _convert_cached_result_to_model_response
* fix linting errors
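For context, the code paths touched by this commit only run when client-side caching is enabled. A minimal sketch (not part of this commit) of how a caller exercises the refactored cache-check path; it assumes a provider API key is set in the environment, uses a placeholder model name, and picks the in-process "local" cache backend to stay self-contained:

import asyncio

import litellm

# Enable LiteLLM's client-side cache. "local" keeps the sketch self-contained;
# Redis, S3, Qdrant, and semantic backends are configured the same way.
litellm.cache = litellm.Cache(type="local")

async def main() -> None:
    # The first call misses and populates the cache; the repeated identical call
    # should be answered by LLMCachingHandler via the new _retrieve_from_cache path.
    for _ in range(2):
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",  # placeholder model; any configured provider works
            messages=[{"role": "user", "content": "hello"}],
            caching=True,
        )
        print(response.choices[0].message.content)

asyncio.run(main())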
Parent: 284deafd0d
Commit: 4ebe6813c0

5 changed files with 217 additions and 135 deletions
@@ -99,11 +99,7 @@ class LLMCachingHandler:
         Raises:
             None
         """
-        from litellm.utils import (
-            CustomStreamWrapper,
-            convert_to_model_response_object,
-            convert_to_streaming_response_async,
-        )
+        from litellm.utils import CustomStreamWrapper
 
         args = args or ()
 
@@ -124,43 +120,13 @@ class LLMCachingHandler:
                 in litellm.cache.supported_call_types
             ):
                 print_verbose("Checking Cache")
-                if call_type == CallTypes.aembedding.value and isinstance(
-                    kwargs["input"], list
-                ):
-                    tasks = []
-                    for idx, i in enumerate(kwargs["input"]):
-                        preset_cache_key = litellm.cache.get_cache_key(
-                            *args, **{**kwargs, "input": i}
-                        )
-                        tasks.append(
-                            litellm.cache.async_get_cache(cache_key=preset_cache_key)
-                        )
-                    cached_result = await asyncio.gather(*tasks)
-                    ## check if cached result is None ##
-                    if cached_result is not None and isinstance(cached_result, list):
-                        # set cached_result to None if all elements are None
-                        if all(result is None for result in cached_result):
-                            cached_result = None
-                elif isinstance(litellm.cache.cache, RedisSemanticCache) or isinstance(
-                    litellm.cache.cache, RedisCache
-                ):
-                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                    kwargs["preset_cache_key"] = (
-                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                    )
-                    cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
-                elif isinstance(litellm.cache.cache, QdrantSemanticCache):
-                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                    kwargs["preset_cache_key"] = (
-                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                    )
-                    cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
-                else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
-                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                    kwargs["preset_cache_key"] = (
-                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                    )
-                    cached_result = litellm.cache.get_cache(*args, **kwargs)
+                cached_result = await self._retrieve_from_cache(
+                    call_type=call_type,
+                    kwargs=kwargs,
+                    args=args,
+                )
 
                 if cached_result is not None and not isinstance(cached_result, list):
                     print_verbose("Cache Hit!")
                     cache_hit = True
@@ -202,69 +168,16 @@ class LLMCachingHandler:
                         stream=kwargs.get("stream", False),
                     )
                     call_type = original_function.__name__
-                    if call_type == CallTypes.acompletion.value and isinstance(
-                        cached_result, dict
-                    ):
-                        if kwargs.get("stream", False) is True:
-                            cached_result = convert_to_streaming_response_async(
-                                response_object=cached_result,
-                            )
-                            cached_result = CustomStreamWrapper(
-                                completion_stream=cached_result,
-                                model=model,
-                                custom_llm_provider="cached_response",
-                                logging_obj=logging_obj,
-                            )
-                        else:
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=ModelResponse(),
-                            )
-                    if call_type == CallTypes.atext_completion.value and isinstance(
-                        cached_result, dict
-                    ):
-                        if kwargs.get("stream", False) is True:
-                            cached_result = convert_to_streaming_response_async(
-                                response_object=cached_result,
-                            )
-                            cached_result = CustomStreamWrapper(
-                                completion_stream=cached_result,
-                                model=model,
-                                custom_llm_provider="cached_response",
-                                logging_obj=logging_obj,
-                            )
-                        else:
-                            cached_result = TextCompletionResponse(**cached_result)
-                    elif call_type == CallTypes.aembedding.value and isinstance(
-                        cached_result, dict
-                    ):
-                        cached_result = convert_to_model_response_object(
-                            response_object=cached_result,
-                            model_response_object=EmbeddingResponse(),
-                            response_type="embedding",
-                        )
-                    elif call_type == CallTypes.arerank.value and isinstance(
-                        cached_result, dict
-                    ):
-                        cached_result = convert_to_model_response_object(
-                            response_object=cached_result,
-                            model_response_object=None,
-                            response_type="rerank",
-                        )
-                    elif call_type == CallTypes.atranscription.value and isinstance(
-                        cached_result, dict
-                    ):
-                        hidden_params = {
-                            "model": "whisper-1",
-                            "custom_llm_provider": custom_llm_provider,
-                            "cache_hit": True,
-                        }
-                        cached_result = convert_to_model_response_object(
-                            response_object=cached_result,
-                            model_response_object=TranscriptionResponse(),
-                            response_type="audio_transcription",
-                            hidden_params=hidden_params,
-                        )
+                    cached_result = self._convert_cached_result_to_model_response(
+                        cached_result=cached_result,
+                        call_type=call_type,
+                        kwargs=kwargs,
+                        logging_obj=logging_obj,
+                        model=model,
+                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
+                        args=args,
+                    )
                     if kwargs.get("stream", False) is False:
                         # LOG SUCCESS
                         asyncio.create_task(
@@ -387,6 +300,151 @@ class LLMCachingHandler:
                 final_embedding_cached_response=final_embedding_cached_response,
             )
 
+    async def _retrieve_from_cache(
+        self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
+    ) -> Optional[Any]:
+        """
+        Internal method to
+        - get cache key
+        - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
+        - async get cache value
+        - return the cached value
+
+        Args:
+            call_type: str:
+            kwargs: Dict[str, Any]:
+            args: Optional[Tuple[Any, ...]] = None:
+
+        Returns:
+            Optional[Any]:
+
+        Raises:
+            None
+        """
+        if litellm.cache is None:
+            return None
+
+        cached_result: Optional[Any] = None
+        if call_type == CallTypes.aembedding.value and isinstance(
+            kwargs["input"], list
+        ):
+            tasks = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = litellm.cache.get_cache_key(
+                    *args, **{**kwargs, "input": i}
+                )
+                tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
+            cached_result = await asyncio.gather(*tasks)
+            ## check if cached result is None ##
+            if cached_result is not None and isinstance(cached_result, list):
+                # set cached_result to None if all elements are None
+                if all(result is None for result in cached_result):
+                    cached_result = None
+        else:
+            preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
+            kwargs["preset_cache_key"] = (
+                preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+            )
+            if litellm.cache._supports_async() is True:
+                cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
+            else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
+                cached_result = litellm.cache.get_cache(*args, **kwargs)
+        return cached_result
+
+    def _convert_cached_result_to_model_response(
+        self,
+        cached_result: Any,
+        call_type: str,
+        kwargs: Dict[str, Any],
+        logging_obj: LiteLLMLoggingObj,
+        model: str,
+        args: Tuple[Any, ...],
+        custom_llm_provider: Optional[str] = None,
+    ) -> Optional[Any]:
+        """
+        Internal method to process the cached result
+
+        Checks the call type and converts the cached result to the appropriate model response object
+        example if call type is text_completion -> returns TextCompletionResponse object
+
+        Args:
+            cached_result: Any:
+            call_type: str:
+            kwargs: Dict[str, Any]:
+            logging_obj: LiteLLMLoggingObj:
+            model: str:
+            custom_llm_provider: Optional[str] = None:
+            args: Optional[Tuple[Any, ...]] = None:
+
+        Returns:
+            Optional[Any]:
+        """
+        from litellm.utils import (
+            CustomStreamWrapper,
+            convert_to_model_response_object,
+            convert_to_streaming_response_async,
+        )
+
+        if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
+            if kwargs.get("stream", False) is True:
+                cached_result = convert_to_streaming_response_async(
+                    response_object=cached_result,
+                )
+                cached_result = CustomStreamWrapper(
+                    completion_stream=cached_result,
+                    model=model,
+                    custom_llm_provider="cached_response",
+                    logging_obj=logging_obj,
+                )
+            else:
+                cached_result = convert_to_model_response_object(
+                    response_object=cached_result,
+                    model_response_object=ModelResponse(),
+                )
+        if call_type == CallTypes.atext_completion.value and isinstance(
+            cached_result, dict
+        ):
+            if kwargs.get("stream", False) is True:
+                cached_result = convert_to_streaming_response_async(
+                    response_object=cached_result,
+                )
+                cached_result = CustomStreamWrapper(
+                    completion_stream=cached_result,
+                    model=model,
+                    custom_llm_provider="cached_response",
+                    logging_obj=logging_obj,
+                )
+            else:
+                cached_result = TextCompletionResponse(**cached_result)
+        elif call_type == CallTypes.aembedding.value and isinstance(
+            cached_result, dict
+        ):
+            cached_result = convert_to_model_response_object(
+                response_object=cached_result,
+                model_response_object=EmbeddingResponse(),
+                response_type="embedding",
+            )
+        elif call_type == CallTypes.arerank.value and isinstance(cached_result, dict):
+            cached_result = convert_to_model_response_object(
+                response_object=cached_result,
+                model_response_object=None,
+                response_type="rerank",
+            )
+        elif call_type == CallTypes.atranscription.value and isinstance(
+            cached_result, dict
+        ):
+            hidden_params = {
+                "model": "whisper-1",
+                "custom_llm_provider": custom_llm_provider,
+                "cache_hit": True,
+            }
+            cached_result = convert_to_model_response_object(
+                response_object=cached_result,
+                model_response_object=TranscriptionResponse(),
+                response_type="audio_transcription",
+                hidden_params=hidden_params,
+            )
+        return cached_result
+
     async def _async_set_cache(
         self,
         result: Any,
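Aside: the embedding branch of _retrieve_from_cache derives one cache key per input item and gathers the lookups concurrently, so a request can be a partial hit. A standalone sketch of that pattern, using a hypothetical in-memory async cache rather than LiteLLM's real cache classes (all names below are illustrative):

import asyncio
from typing import Any, Dict, List, Optional

class InMemoryAsyncCache:
    """Hypothetical stand-in for litellm.cache: async get keyed by string."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    async def async_get_cache(self, cache_key: str) -> Optional[Any]:
        return self._store.get(cache_key)

    def set_cache(self, cache_key: str, value: Any) -> None:
        self._store[cache_key] = value

def get_cache_key(model: str, text: str) -> str:
    # Illustrative key derivation; the real get_cache_key hashes the full call kwargs.
    return f"{model}:{text}"

async def retrieve_embeddings(
    cache: InMemoryAsyncCache, model: str, inputs: List[str]
) -> Optional[List[Any]]:
    # One lookup per input item, run concurrently; None marks a per-item miss.
    tasks = [cache.async_get_cache(cache_key=get_cache_key(model, i)) for i in inputs]
    results = await asyncio.gather(*tasks)
    # Mirror the handler's behaviour: an all-miss list counts as a full cache miss.
    return None if all(r is None for r in results) else list(results)

async def demo() -> None:
    cache = InMemoryAsyncCache()
    cache.set_cache(get_cache_key("text-embedding-3-small", "hello"), [0.1, 0.2])
    print(await retrieve_embeddings(cache, "text-embedding-3-small", ["hello", "world"]))
    # -> [[0.1, 0.2], None]  (partial hit: only the missing item needs a provider call)

asyncio.run(demo())

Returning None only when every per-item lookup misses matches how the helper above collapses an all-None gather result back into a single cache miss.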