(refactor) caching use LLMCachingHandler for async_get_cache and set_cache (#6208)

* use folder for caching

* fix importing caching

* fix clickhouse pyright

* fix linting

* fix correctly pass kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix refactor handle utils.py

* fix test_embedding_caching_azure_individual_items_reordered
Ishaan Jaff 2024-10-14 16:34:01 +05:30 committed by GitHub
parent 20e50d7002
commit 4d1b4beb3d
96 changed files with 690 additions and 489 deletions
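
The net effect of the refactor: the async client wrapper in litellm/utils.py stops branching on cache backends inline and delegates cache reads and writes to a module-level LLMCachingHandler. A condensed sketch of the resulting flow is below; the call signatures are copied from the hunks that follow, but the simplified wrapper and its control flow are illustrative, not the actual litellm code.

# Illustrative sketch of the post-refactor caching flow (not the real wrapper).
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler

_llm_caching_handler = LLMCachingHandler()

async def _wrapped_call(original_function, model, logging_obj, start_time, call_type, *args, **kwargs):
    # 1) Cache lookup is delegated to the handler.
    _caching_handler_response: CachingHandlerResponse = (
        await _llm_caching_handler._async_get_cache(
            model=model,
            original_function=original_function,
            logging_obj=logging_obj,
            start_time=start_time,
            call_type=call_type,
            kwargs=kwargs,
            args=args,
        )
    )
    if (
        _caching_handler_response.cached_result is not None
        and _caching_handler_response.final_embedding_cached_response is None
    ):
        return _caching_handler_response.cached_result  # full cache hit
    if _caching_handler_response.embedding_all_elements_cache_hit is True:
        return _caching_handler_response.final_embedding_cached_response

    # 2) Cache miss (or partial embedding hit): call the model.
    result = await original_function(*args, **kwargs)

    # 3) Cache write is also delegated to the handler.
    await _llm_caching_handler._async_set_cache(
        result=result,
        original_function=original_function,
        kwargs=kwargs,
        args=args,
    )
    return result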

@@ -56,7 +56,10 @@ import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.audio_utils.utils
import litellm.litellm_core_utils.json_validation_rule
from litellm.caching import DualCache
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
_llm_caching_handler = LLMCachingHandler()
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.exception_mapping_utils import (
@@ -146,7 +149,13 @@ from typing import (
from openai import OpenAIError as OriginalError
from ._logging import verbose_logger
from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .caching.caching import (
Cache,
QdrantSemanticCache,
RedisCache,
RedisSemanticCache,
S3Cache,
)
from .exceptions import (
APIConnectionError,
APIError,
@@ -1121,299 +1130,26 @@ def client(original_function):
print_verbose(
f"ASYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache'): {kwargs.get('cache', None)}"
)
# if caching is false, don't run this
final_embedding_cached_response = None
_caching_handler_response: CachingHandlerResponse = (
await _llm_caching_handler._async_get_cache(
model=model,
original_function=original_function,
logging_obj=logging_obj,
start_time=start_time,
call_type=call_type,
kwargs=kwargs,
args=args,
)
)
if (
(kwargs.get("caching", None) is None and litellm.cache is not None)
or kwargs.get("caching", False) is True
) and (
kwargs.get("cache", {}).get("no-cache", False) is not True
): # allow users to control returning cached responses from the completion function
# checking cache
print_verbose("INSIDE CHECKING CACHE")
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
):
print_verbose("Checking Cache")
if call_type == CallTypes.aembedding.value and isinstance(
kwargs["input"], list
):
tasks = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
)
tasks.append(
litellm.cache.async_get_cache(
cache_key=preset_cache_key
)
)
cached_result = await asyncio.gather(*tasks)
## check if cached result is None ##
if cached_result is not None and isinstance(
cached_result, list
):
if len(cached_result) == 1 and cached_result[0] is None:
cached_result = None
elif isinstance(
litellm.cache.cache, RedisSemanticCache
) or isinstance(litellm.cache.cache, RedisCache):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = await litellm.cache.async_get_cache(
*args, **kwargs
)
elif isinstance(litellm.cache.cache, QdrantSemanticCache):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = await litellm.cache.async_get_cache(
*args, **kwargs
)
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None and not isinstance(
cached_result, list
):
print_verbose("Cache Hit!", log_level="INFO")
cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
print_verbose(
f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
)
logging_obj.update_environment_variables(
model=model,
user=kwargs.get("user", None),
optional_params={},
litellm_params={
"logger_fn": kwargs.get("logger_fn", None),
"acompletion": True,
"metadata": kwargs.get("metadata", {}),
"model_info": kwargs.get("model_info", {}),
"proxy_server_request": kwargs.get(
"proxy_server_request", None
),
"preset_cache_key": kwargs.get(
"preset_cache_key", None
),
"stream_response": kwargs.get("stream_response", {}),
"api_base": kwargs.get("api_base", ""),
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
original_response=str(cached_result),
additional_args=None,
stream=kwargs.get("stream", False),
)
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(
cached_result, dict
):
if kwargs.get("stream", False) is True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
if (
call_type == CallTypes.atext_completion.value
and isinstance(cached_result, dict)
):
if kwargs.get("stream", False) is True:
cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
cached_result = TextCompletionResponse(**cached_result)
elif call_type == CallTypes.aembedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif call_type == CallTypes.arerank.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=None,
response_type="rerank",
)
elif call_type == CallTypes.atranscription.value and isinstance(
cached_result, dict
):
hidden_params = {
"model": "whisper-1",
"custom_llm_provider": custom_llm_provider,
"cache_hit": True,
}
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=TranscriptionResponse(),
response_type="audio_transcription",
hidden_params=hidden_params,
)
if kwargs.get("stream", False) is False:
# LOG SUCCESS
asyncio.create_task(
logging_obj.async_success_handler(
cached_result, start_time, end_time, cache_hit
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
cache_key = kwargs.get("preset_cache_key", None)
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return cached_result
elif (
call_type == CallTypes.aembedding.value
and cached_result is not None
and isinstance(cached_result, list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
remaining_list = []
non_null_list = []
for idx, cr in enumerate(cached_result):
if cr is None:
remaining_list.append(kwargs["input"][idx])
else:
non_null_list.append((idx, cr))
original_kwargs_input = kwargs["input"]
kwargs["input"] = remaining_list
if len(non_null_list) > 0:
print_verbose(
f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
)
final_embedding_cached_response = EmbeddingResponse(
model=kwargs.get("model"),
data=[None] * len(original_kwargs_input),
)
final_embedding_cached_response._hidden_params[
"cache_hit"
] = True
if (
    _caching_handler_response.cached_result is not None
    and _caching_handler_response.final_embedding_cached_response is None
):
    return _caching_handler_response.cached_result
elif _caching_handler_response.embedding_all_elements_cache_hit is True:
    return _caching_handler_response.final_embedding_cached_response
for val in non_null_list:
idx, cr = val # (idx, cr) tuple
if cr is not None:
final_embedding_cached_response.data[idx] = (
Embedding(
embedding=cr["embedding"],
index=idx,
object="embedding",
)
)
if len(remaining_list) == 0:
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get(
"custom_llm_provider", None
),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
print_verbose(
f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
)
logging_obj.update_environment_variables(
model=model,
user=kwargs.get("user", None),
optional_params={},
litellm_params={
"logger_fn": kwargs.get("logger_fn", None),
"acompletion": True,
"metadata": kwargs.get("metadata", {}),
"model_info": kwargs.get("model_info", {}),
"proxy_server_request": kwargs.get(
"proxy_server_request", None
),
"preset_cache_key": kwargs.get(
"preset_cache_key", None
),
"stream_response": kwargs.get(
"stream_response", {}
),
"api_base": "",
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
original_response=str(final_embedding_cached_response),
additional_args=None,
stream=kwargs.get("stream", False),
)
asyncio.create_task(
logging_obj.async_success_handler(
final_embedding_cached_response,
start_time,
end_time,
cache_hit,
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(
final_embedding_cached_response,
start_time,
end_time,
cache_hit,
),
).start()
return final_embedding_cached_response
# MODEL CALL
result = await original_function(*args, **kwargs)
end_time = datetime.datetime.now()
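
The hunk above only consumes CachingHandlerResponse; its definition lives in litellm/caching/caching_handler.py and is not part of this excerpt. From the three attributes used here it behaves roughly like the container below (the field names come from the diff, the dataclass form and defaults are assumptions).

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class CachingHandlerResponse:
    # Shape inferred from usage above; not the actual definition.
    cached_result: Optional[Any] = None  # set when the whole request was served from cache
    final_embedding_cached_response: Optional[Any] = None  # partially cached EmbeddingResponse for list inputs
    embedding_all_elements_cache_hit: bool = False  # every embedding input item was a cache hit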
@@ -1467,51 +1203,14 @@ def client(original_function):
original_response=result, model=model, optional_params=kwargs
)
# [OPTIONAL] ADD TO CACHE
if (
(litellm.cache is not None)
and litellm.cache.supported_call_types is not None
and (
str(original_function.__name__)
in litellm.cache.supported_call_types
)
and (kwargs.get("cache", {}).get("no-store", False) is not True)
):
if (
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.EmbeddingResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, RerankResponse)
):
if (
isinstance(result, EmbeddingResponse)
and isinstance(kwargs["input"], list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
asyncio.create_task(
litellm.cache.async_add_cache_pipeline(
result, *args, **kwargs
)
)
elif isinstance(litellm.cache.cache, S3Cache):
threading.Thread(
target=litellm.cache.add_cache,
args=(result,) + args,
kwargs=kwargs,
).start()
else:
asyncio.create_task(
litellm.cache.async_add_cache(
result.json(), *args, **kwargs
)
)
else:
asyncio.create_task(
litellm.cache.async_add_cache(result, *args, **kwargs)
)
## Add response to cache
await _llm_caching_handler._async_set_cache(
result=result,
original_function=original_function,
kwargs=kwargs,
args=args,
)
# LOG SUCCESS - handle streaming success logging in the _next_ object
print_verbose(
f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
@@ -1528,24 +1227,32 @@ def client(original_function):
# REBUILD EMBEDDING CACHING
if (
isinstance(result, EmbeddingResponse)
and final_embedding_cached_response is not None
and final_embedding_cached_response.data is not None
and _caching_handler_response.final_embedding_cached_response
is not None
and _caching_handler_response.final_embedding_cached_response.data
is not None
):
idx = 0
final_data_list = []
for item in final_embedding_cached_response.data:
for (
item
) in _caching_handler_response.final_embedding_cached_response.data:
if item is None and result.data is not None:
final_data_list.append(result.data[idx])
idx += 1
else:
final_data_list.append(item)
final_embedding_cached_response.data = final_data_list
final_embedding_cached_response._hidden_params["cache_hit"] = True
final_embedding_cached_response._response_ms = (
_caching_handler_response.final_embedding_cached_response.data = (
final_data_list
)
_caching_handler_response.final_embedding_cached_response._hidden_params[
"cache_hit"
] = True
_caching_handler_response.final_embedding_cached_response._response_ms = (
end_time - start_time
).total_seconds() * 1000
return final_embedding_cached_response
return _caching_handler_response.final_embedding_cached_response
return result
except Exception as e:
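
The final hunk stitches fresh embeddings back into the partially cached response: the handler leaves None placeholders at cache-miss positions in final_embedding_cached_response.data, and those slots are filled from result.data in request order. The same merge as a standalone, hypothetical helper:

from typing import Any, List, Optional

def _merge_embedding_data(cached: List[Optional[Any]], fresh: List[Any]) -> List[Any]:
    # Hypothetical helper mirroring the rebuild loop above: None marks a cache
    # miss, and misses are consumed from `fresh` in order.
    merged: List[Any] = []
    fresh_idx = 0
    for item in cached:
        if item is None:
            merged.append(fresh[fresh_idx])
            fresh_idx += 1
        else:
            merged.append(item)
    return merged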