Mirror of https://github.com/BerriAI/litellm.git
(refactor) caching use LLMCachingHandler for async_get_cache and set_cache (#6208)
* use folder for caching
* fix importing caching
* fix clickhouse pyright
* fix linting
* fix correctly pass kwargs and args
* fix test case for embedding
* fix linting
* fix embedding caching logic
* fix refactor handle utils.py
* fix test_embedding_caching_azure_individual_items_reordered
Parent: 20e50d7002
Commit: 4d1b4beb3d
96 changed files with 690 additions and 489 deletions
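Only the litellm/utils.py portion of the diff is shown below. The change replaces the inline cache lookup and cache write logic in the async client wrapper with two calls into a shared LLMCachingHandler instance. For orientation, here is a minimal, self-contained sketch of the resulting control flow; it reuses the names LLMCachingHandler, CachingHandlerResponse, _async_get_cache and _async_set_cache from the diff, but the cache-key scheme, the in-memory store and the example function are illustrative stand-ins, not litellm code.

# Orientation sketch only - a simplified stand-in for the refactored wrapper in litellm/utils.py.
import asyncio
import functools
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class CachingHandlerResponse:
    # fields mirror the attributes the wrapper reads in the diff below
    cached_result: Optional[Any] = None
    final_embedding_cached_response: Optional[Any] = None
    embedding_all_elements_cache_hit: bool = False


class LLMCachingHandler:
    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}  # stand-in for litellm.cache

    def _key(self, original_function: Callable, kwargs: Dict) -> str:
        # illustrative key scheme, not litellm's get_cache_key logic
        return f"{original_function.__name__}:{sorted(kwargs.items())!r}"

    async def _async_get_cache(self, *, original_function: Callable, kwargs: Dict) -> CachingHandlerResponse:
        return CachingHandlerResponse(cached_result=self._store.get(self._key(original_function, kwargs)))

    async def _async_set_cache(self, *, result: Any, original_function: Callable, kwargs: Dict) -> None:
        self._store[self._key(original_function, kwargs)] = result


_llm_caching_handler = LLMCachingHandler()


def client(original_function: Callable) -> Callable:
    @functools.wraps(original_function)
    async def wrapper_async(*args, **kwargs):
        # cache lookup is now a single handler call instead of ~300 inline lines
        handler_response = await _llm_caching_handler._async_get_cache(
            original_function=original_function, kwargs=kwargs
        )
        if handler_response.cached_result is not None:
            return handler_response.cached_result
        # cache miss: run the model call, then hand the result back to the handler
        result = await original_function(*args, **kwargs)
        await _llm_caching_handler._async_set_cache(
            result=result, original_function=original_function, kwargs=kwargs
        )
        return result

    return wrapper_async


@client
async def fake_completion(model: str, prompt: str) -> str:
    return f"response from {model} for {prompt!r}"


if __name__ == "__main__":
    print(asyncio.run(fake_completion(model="gpt-4o", prompt="hi")))  # model call
    print(asyncio.run(fake_completion(model="gpt-4o", prompt="hi")))  # served from cache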
litellm/utils.py (397 changed lines)
@@ -56,7 +56,10 @@ import litellm._service_logger  # for storing API inputs, outputs, and metadata
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.audio_utils.utils
 import litellm.litellm_core_utils.json_validation_rule
-from litellm.caching import DualCache
+from litellm.caching.caching import DualCache
+from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
+
+_llm_caching_handler = LLMCachingHandler()
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.exception_mapping_utils import (
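This import hunk, and the relative-import hunk that follows, correspond to the "use folder for caching" item in the commit message: call sites stop importing from the old single-module path litellm.caching and instead import from litellm.caching.caching plus the new litellm.caching.caching_handler module, and utils.py now builds one module-level handler instance. The package layout this implies is roughly the following (only the two module names visible in this diff are confirmed; anything else under the package is not shown here):

litellm/
    caching/
        caching.py            # Cache, DualCache, RedisCache, RedisSemanticCache, QdrantSemanticCache, S3Cache, ...
        caching_handler.py    # LLMCachingHandler, CachingHandlerResponse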
@@ -146,7 +149,13 @@ from typing import (
 from openai import OpenAIError as OriginalError
 
 from ._logging import verbose_logger
-from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
+from .caching.caching import (
+    Cache,
+    QdrantSemanticCache,
+    RedisCache,
+    RedisSemanticCache,
+    S3Cache,
+)
 from .exceptions import (
     APIConnectionError,
     APIError,
@@ -1121,299 +1130,26 @@ def client(original_function):
             print_verbose(
                 f"ASYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache'): {kwargs.get('cache', None)}"
             )
-            # if caching is false, don't run this
-            final_embedding_cached_response = None
-            if (
-                (kwargs.get("caching", None) is None and litellm.cache is not None)
-                or kwargs.get("caching", False) is True
-            ) and (
-                kwargs.get("cache", {}).get("no-cache", False) is not True
-            ):  # allow users to control returning cached responses from the completion function
-                # checking cache
-                print_verbose("INSIDE CHECKING CACHE")
-                if (
-                    litellm.cache is not None
-                    and litellm.cache.supported_call_types is not None
-                    and str(original_function.__name__)
-                    in litellm.cache.supported_call_types
-                ):
-                    print_verbose("Checking Cache")
-                    if call_type == CallTypes.aembedding.value and isinstance(
-                        kwargs["input"], list
-                    ):
-                        tasks = []
-                        for idx, i in enumerate(kwargs["input"]):
-                            preset_cache_key = litellm.cache.get_cache_key(
-                                *args, **{**kwargs, "input": i}
-                            )
-                            tasks.append(
-                                litellm.cache.async_get_cache(
-                                    cache_key=preset_cache_key
-                                )
-                            )
-                        cached_result = await asyncio.gather(*tasks)
-                        ## check if cached result is None ##
-                        if cached_result is not None and isinstance(
-                            cached_result, list
-                        ):
-                            if len(cached_result) == 1 and cached_result[0] is None:
-                                cached_result = None
-                    elif isinstance(
-                        litellm.cache.cache, RedisSemanticCache
-                    ) or isinstance(litellm.cache.cache, RedisCache):
-                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                        kwargs["preset_cache_key"] = (
-                            preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                        )
-                        cached_result = await litellm.cache.async_get_cache(
-                            *args, **kwargs
-                        )
-                    elif isinstance(litellm.cache.cache, QdrantSemanticCache):
-                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                        kwargs["preset_cache_key"] = (
-                            preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                        )
-                        cached_result = await litellm.cache.async_get_cache(
-                            *args, **kwargs
-                        )
-                    else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
-                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                        kwargs["preset_cache_key"] = (
-                            preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                        )
-                        cached_result = litellm.cache.get_cache(*args, **kwargs)
-                    if cached_result is not None and not isinstance(
-                        cached_result, list
-                    ):
-                        print_verbose("Cache Hit!", log_level="INFO")
-                        cache_hit = True
-                        end_time = datetime.datetime.now()
-                        (
-                            model,
-                            custom_llm_provider,
-                            dynamic_api_key,
-                            api_base,
-                        ) = litellm.get_llm_provider(
-                            model=model,
-                            custom_llm_provider=kwargs.get("custom_llm_provider", None),
-                            api_base=kwargs.get("api_base", None),
-                            api_key=kwargs.get("api_key", None),
-                        )
-                        print_verbose(
-                            f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
-                        )
-                        logging_obj.update_environment_variables(
-                            model=model,
-                            user=kwargs.get("user", None),
-                            optional_params={},
-                            litellm_params={
-                                "logger_fn": kwargs.get("logger_fn", None),
-                                "acompletion": True,
-                                "metadata": kwargs.get("metadata", {}),
-                                "model_info": kwargs.get("model_info", {}),
-                                "proxy_server_request": kwargs.get(
-                                    "proxy_server_request", None
-                                ),
-                                "preset_cache_key": kwargs.get(
-                                    "preset_cache_key", None
-                                ),
-                                "stream_response": kwargs.get("stream_response", {}),
-                                "api_base": kwargs.get("api_base", ""),
-                            },
-                            input=kwargs.get("messages", ""),
-                            api_key=kwargs.get("api_key", None),
-                            original_response=str(cached_result),
-                            additional_args=None,
-                            stream=kwargs.get("stream", False),
-                        )
-                        call_type = original_function.__name__
-                        if call_type == CallTypes.acompletion.value and isinstance(
-                            cached_result, dict
-                        ):
-                            if kwargs.get("stream", False) is True:
-                                cached_result = convert_to_streaming_response_async(
-                                    response_object=cached_result,
-                                )
-                                cached_result = CustomStreamWrapper(
-                                    completion_stream=cached_result,
-                                    model=model,
-                                    custom_llm_provider="cached_response",
-                                    logging_obj=logging_obj,
-                                )
-                            else:
-                                cached_result = convert_to_model_response_object(
-                                    response_object=cached_result,
-                                    model_response_object=ModelResponse(),
-                                )
-                        if (
-                            call_type == CallTypes.atext_completion.value
-                            and isinstance(cached_result, dict)
-                        ):
-                            if kwargs.get("stream", False) is True:
-                                cached_result = convert_to_streaming_response_async(
-                                    response_object=cached_result,
-                                )
-                                cached_result = CustomStreamWrapper(
-                                    completion_stream=cached_result,
-                                    model=model,
-                                    custom_llm_provider="cached_response",
-                                    logging_obj=logging_obj,
-                                )
-                            else:
-                                cached_result = TextCompletionResponse(**cached_result)
-                        elif call_type == CallTypes.aembedding.value and isinstance(
-                            cached_result, dict
-                        ):
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=EmbeddingResponse(),
-                                response_type="embedding",
-                            )
-                        elif call_type == CallTypes.arerank.value and isinstance(
-                            cached_result, dict
-                        ):
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=None,
-                                response_type="rerank",
-                            )
-                        elif call_type == CallTypes.atranscription.value and isinstance(
-                            cached_result, dict
-                        ):
-                            hidden_params = {
-                                "model": "whisper-1",
-                                "custom_llm_provider": custom_llm_provider,
-                                "cache_hit": True,
-                            }
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=TranscriptionResponse(),
-                                response_type="audio_transcription",
-                                hidden_params=hidden_params,
-                            )
-                        if kwargs.get("stream", False) is False:
-                            # LOG SUCCESS
-                            asyncio.create_task(
-                                logging_obj.async_success_handler(
-                                    cached_result, start_time, end_time, cache_hit
-                                )
-                            )
-                            threading.Thread(
-                                target=logging_obj.success_handler,
-                                args=(cached_result, start_time, end_time, cache_hit),
-                            ).start()
-                        cache_key = kwargs.get("preset_cache_key", None)
-                        if (
-                            isinstance(cached_result, BaseModel)
-                            or isinstance(cached_result, CustomStreamWrapper)
-                        ) and hasattr(cached_result, "_hidden_params"):
-                            cached_result._hidden_params["cache_key"] = cache_key  # type: ignore
-                        return cached_result
-                    elif (
-                        call_type == CallTypes.aembedding.value
-                        and cached_result is not None
-                        and isinstance(cached_result, list)
-                        and litellm.cache is not None
-                        and not isinstance(
-                            litellm.cache.cache, S3Cache
-                        )  # s3 doesn't support bulk writing. Exclude.
-                    ):
-                        remaining_list = []
-                        non_null_list = []
-                        for idx, cr in enumerate(cached_result):
-                            if cr is None:
-                                remaining_list.append(kwargs["input"][idx])
-                            else:
-                                non_null_list.append((idx, cr))
-                        original_kwargs_input = kwargs["input"]
-                        kwargs["input"] = remaining_list
-                        if len(non_null_list) > 0:
-                            print_verbose(
-                                f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
-                            )
-                            final_embedding_cached_response = EmbeddingResponse(
-                                model=kwargs.get("model"),
-                                data=[None] * len(original_kwargs_input),
-                            )
-                            final_embedding_cached_response._hidden_params[
-                                "cache_hit"
-                            ] = True
-                            for val in non_null_list:
-                                idx, cr = val  # (idx, cr) tuple
-                                if cr is not None:
-                                    final_embedding_cached_response.data[idx] = (
-                                        Embedding(
-                                            embedding=cr["embedding"],
-                                            index=idx,
-                                            object="embedding",
-                                        )
-                                    )
-                        if len(remaining_list) == 0:
-                            # LOG SUCCESS
-                            cache_hit = True
-                            end_time = datetime.datetime.now()
-                            (
-                                model,
-                                custom_llm_provider,
-                                dynamic_api_key,
-                                api_base,
-                            ) = litellm.get_llm_provider(
-                                model=model,
-                                custom_llm_provider=kwargs.get(
-                                    "custom_llm_provider", None
-                                ),
-                                api_base=kwargs.get("api_base", None),
-                                api_key=kwargs.get("api_key", None),
-                            )
-                            print_verbose(
-                                f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
-                            )
-                            logging_obj.update_environment_variables(
-                                model=model,
-                                user=kwargs.get("user", None),
-                                optional_params={},
-                                litellm_params={
-                                    "logger_fn": kwargs.get("logger_fn", None),
-                                    "acompletion": True,
-                                    "metadata": kwargs.get("metadata", {}),
-                                    "model_info": kwargs.get("model_info", {}),
-                                    "proxy_server_request": kwargs.get(
-                                        "proxy_server_request", None
-                                    ),
-                                    "preset_cache_key": kwargs.get(
-                                        "preset_cache_key", None
-                                    ),
-                                    "stream_response": kwargs.get(
-                                        "stream_response", {}
-                                    ),
-                                    "api_base": "",
-                                },
-                                input=kwargs.get("messages", ""),
-                                api_key=kwargs.get("api_key", None),
-                                original_response=str(final_embedding_cached_response),
-                                additional_args=None,
-                                stream=kwargs.get("stream", False),
-                            )
-                            asyncio.create_task(
-                                logging_obj.async_success_handler(
-                                    final_embedding_cached_response,
-                                    start_time,
-                                    end_time,
-                                    cache_hit,
-                                )
-                            )
-                            threading.Thread(
-                                target=logging_obj.success_handler,
-                                args=(
-                                    final_embedding_cached_response,
-                                    start_time,
-                                    end_time,
-                                    cache_hit,
-                                ),
-                            ).start()
-                            return final_embedding_cached_response
+            _caching_handler_response: CachingHandlerResponse = (
+                await _llm_caching_handler._async_get_cache(
+                    model=model,
+                    original_function=original_function,
+                    logging_obj=logging_obj,
+                    start_time=start_time,
+                    call_type=call_type,
+                    kwargs=kwargs,
+                    args=args,
+                )
+            )
+            if (
+                _caching_handler_response.cached_result is not None
+                and _caching_handler_response.final_embedding_cached_response is None
+            ):
+                return _caching_handler_response.cached_result
+
+            elif _caching_handler_response.embedding_all_elements_cache_hit is True:
+                return _caching_handler_response.final_embedding_cached_response
+
             # MODEL CALL
             result = await original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
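The two early returns added above encode three possible outcomes of the handler lookup. A commented restatement, written as a free-standing helper for readability (the function name and the None-means-miss convention are illustrative; the real wrapper keeps this logic inline):

# Restates the added branch logic above; `resp` is the CachingHandlerResponse
# returned by _llm_caching_handler._async_get_cache(...).
def resolve_cache_lookup(resp):
    # 1) full cache hit for a non-embedding call (completion, text_completion,
    #    rerank, transcription, ...): return the rehydrated cached object
    if resp.cached_result is not None and resp.final_embedding_cached_response is None:
        return resp.cached_result
    # 2) every element of an embedding batch was cached: return the rebuilt EmbeddingResponse
    if resp.embedding_all_elements_cache_hit is True:
        return resp.final_embedding_cached_response
    # 3) miss, or only part of an embedding batch was cached: the wrapper falls
    #    through to `result = await original_function(*args, **kwargs)`
    return None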
@@ -1467,51 +1203,14 @@ def client(original_function):
                 original_response=result, model=model, optional_params=kwargs
             )
 
-            # [OPTIONAL] ADD TO CACHE
-            if (
-                (litellm.cache is not None)
-                and litellm.cache.supported_call_types is not None
-                and (
-                    str(original_function.__name__)
-                    in litellm.cache.supported_call_types
-                )
-                and (kwargs.get("cache", {}).get("no-store", False) is not True)
-            ):
-                if (
-                    isinstance(result, litellm.ModelResponse)
-                    or isinstance(result, litellm.EmbeddingResponse)
-                    or isinstance(result, TranscriptionResponse)
-                    or isinstance(result, RerankResponse)
-                ):
-                    if (
-                        isinstance(result, EmbeddingResponse)
-                        and isinstance(kwargs["input"], list)
-                        and litellm.cache is not None
-                        and not isinstance(
-                            litellm.cache.cache, S3Cache
-                        )  # s3 doesn't support bulk writing. Exclude.
-                    ):
-                        asyncio.create_task(
-                            litellm.cache.async_add_cache_pipeline(
-                                result, *args, **kwargs
-                            )
-                        )
-                    elif isinstance(litellm.cache.cache, S3Cache):
-                        threading.Thread(
-                            target=litellm.cache.add_cache,
-                            args=(result,) + args,
-                            kwargs=kwargs,
-                        ).start()
-                    else:
-                        asyncio.create_task(
-                            litellm.cache.async_add_cache(
-                                result.json(), *args, **kwargs
-                            )
-                        )
-                else:
-                    asyncio.create_task(
-                        litellm.cache.async_add_cache(result, *args, **kwargs)
-                    )
+            ## Add response to cache
+            await _llm_caching_handler._async_set_cache(
+                result=result,
+                original_function=original_function,
+                kwargs=kwargs,
+                args=args,
+            )
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object
             print_verbose(
                 f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
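On the write side, the removed branches above are what _async_set_cache now has to cover: bulk pipeline writes for embedding batches, a background thread for the synchronous S3 backend, and a plain async write otherwise. A hedged sketch of that dispatch, reusing only cache methods that appear in the removed code (this is not the actual litellm.caching.caching_handler implementation, and it collapses the outer response-type checks):

# Hedged sketch of the cache-write dispatch the handler absorbs; not litellm's code.
import asyncio
import threading

import litellm
from litellm.caching.caching import S3Cache


async def _async_set_cache_sketch(result, original_function, *args, **kwargs) -> None:
    cache = litellm.cache
    if (
        cache is None
        or cache.supported_call_types is None
        or str(original_function.__name__) not in cache.supported_call_types
        or kwargs.get("cache", {}).get("no-store", False) is True
    ):
        return
    if (
        isinstance(result, litellm.EmbeddingResponse)
        and isinstance(kwargs.get("input"), list)
        and not isinstance(cache.cache, S3Cache)  # s3 doesn't support bulk writing
    ):
        # one cache entry per embedding input, written as a pipeline
        asyncio.create_task(cache.async_add_cache_pipeline(result, *args, **kwargs))
    elif isinstance(cache.cache, S3Cache):
        # boto3 is synchronous, so write from a thread instead of blocking the event loop
        threading.Thread(
            target=cache.add_cache, args=(result,) + args, kwargs=kwargs
        ).start()
    else:
        asyncio.create_task(cache.async_add_cache(result.json(), *args, **kwargs))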
@@ -1528,24 +1227,32 @@ def client(original_function):
             # REBUILD EMBEDDING CACHING
             if (
                 isinstance(result, EmbeddingResponse)
-                and final_embedding_cached_response is not None
-                and final_embedding_cached_response.data is not None
+                and _caching_handler_response.final_embedding_cached_response
+                is not None
+                and _caching_handler_response.final_embedding_cached_response.data
+                is not None
             ):
                 idx = 0
                 final_data_list = []
-                for item in final_embedding_cached_response.data:
+                for (
+                    item
+                ) in _caching_handler_response.final_embedding_cached_response.data:
                     if item is None and result.data is not None:
                         final_data_list.append(result.data[idx])
                         idx += 1
                     else:
                         final_data_list.append(item)
 
-                final_embedding_cached_response.data = final_data_list
-                final_embedding_cached_response._hidden_params["cache_hit"] = True
-                final_embedding_cached_response._response_ms = (
+                _caching_handler_response.final_embedding_cached_response.data = (
+                    final_data_list
+                )
+                _caching_handler_response.final_embedding_cached_response._hidden_params[
+                    "cache_hit"
+                ] = True
+                _caching_handler_response.final_embedding_cached_response._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000
-                return final_embedding_cached_response
+                return _caching_handler_response.final_embedding_cached_response
 
             return result
         except Exception as e:
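The rebuild step above fills the None placeholders of the partially cached EmbeddingResponse with the freshly computed embeddings, relying on the new result preserving the order of the inputs that missed the cache. A self-contained sketch of that merge, with plain lists standing in for EmbeddingResponse.data:

# Plain-list sketch of the rebuild loop: `cached` has one slot per original input
# (None where the cache missed), `fresh` has the newly computed embeddings in order.
from typing import List, Optional, Sequence


def merge_embeddings(cached: Sequence[Optional[List[float]]], fresh: Sequence[List[float]]) -> List[List[float]]:
    merged: List[List[float]] = []
    idx = 0  # position in `fresh`, advanced only when a missing slot is filled
    for item in cached:
        if item is None:
            merged.append(list(fresh[idx]))
            idx += 1
        else:
            merged.append(list(item))
    return merged


if __name__ == "__main__":
    cached = [[0.1, 0.2], None, [0.5, 0.6], None]  # inputs 1 and 3 missed the cache
    fresh = [[0.3, 0.4], [0.7, 0.8]]               # model output for the two misses
    assert merge_embeddings(cached, fresh) == [
        [0.1, 0.2],
        [0.3, 0.4],
        [0.5, 0.6],
        [0.7, 0.8],
    ]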