forked from phoenix/litellm-mirror
(refactor) sync caching - use LLMCachingHandler class for get_cache (#6249)

* caching - use _sync_set_cache
* add sync _sync_add_streaming_response_to_cache
* use caching class for cache storage
* fix use _sync_get_cache
* fix circular import
* use _update_litellm_logging_obj_environment
* use one helper for _process_async_embedding_cached_response
* fix _is_call_type_supported_by_cache
* fix checking cache
* fix sync get cache
* fix use _combine_cached_embedding_response_with_api_result
* fix _update_litellm_logging_obj_environment
* adjust test_redis_cache_acompletion_stream_bedrock
This commit is contained in:
parent 183bd5d873
commit 97ba4eea7d
3 changed files with 434 additions and 294 deletions
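For orientation, a minimal sketch (not part of this commit) of the user-facing sync caching flow that the new LLMCachingHandler._sync_get_cache backs. It assumes litellm's exported Cache class and the caching=True completion kwarg; the model name and messages are illustrative only, and running it requires provider credentials.

import litellm

# In-memory cache by default; Redis/S3 backends are configured via Cache(...) kwargs.
litellm.cache = litellm.Cache()

messages = [{"role": "user", "content": "hello"}]

# First call goes to the provider and writes the response to the cache.
response_1 = litellm.completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# Second identical call is served from litellm.cache.get_cache(), which the
# @client wrapper now reaches through LLMCachingHandler._sync_get_cache.
response_2 = litellm.completion(model="gpt-3.5-turbo", messages=messages, caching=True)

assert response_1.choices[0].message.content == response_2.choices[0].message.content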
@@ -13,7 +13,18 @@ In each method it will call the appropriate method from caching.py
 import asyncio
 import datetime
 import threading
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from pydantic import BaseModel
 
@@ -41,8 +52,10 @@ from litellm.types.utils import (
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+    from litellm.utils import CustomStreamWrapper
 else:
     LiteLLMLoggingObj = Any
+    CustomStreamWrapper = Any
 
 
 class CachingHandlerResponse(BaseModel):
@@ -108,6 +121,7 @@ class LLMCachingHandler:
         args = args or ()
 
         final_embedding_cached_response: Optional[EmbeddingResponse] = None
+        embedding_all_elements_cache_hit: bool = False
         cached_result: Optional[Any] = None
         if (
             (kwargs.get("caching", None) is None and litellm.cache is not None)
@@ -115,16 +129,10 @@ class LLMCachingHandler:
         ) and (
             kwargs.get("cache", {}).get("no-cache", False) is not True
         ):  # allow users to control returning cached responses from the completion function
-            # checking cache
-            print_verbose("INSIDE CHECKING CACHE")
-            if (
-                litellm.cache is not None
-                and litellm.cache.supported_call_types is not None
-                and str(original_function.__name__)
-                in litellm.cache.supported_call_types
-            ):
+            if litellm.cache is not None and self._is_call_type_supported_by_cache(
+                original_function=original_function
+            ):
                 print_verbose("Checking Cache")
-
                 cached_result = await self._retrieve_from_cache(
                     call_type=call_type,
                     kwargs=kwargs,
@@ -135,42 +143,20 @@ class LLMCachingHandler:
                     print_verbose("Cache Hit!")
                     cache_hit = True
                     end_time = datetime.datetime.now()
-                    (
-                        model,
-                        custom_llm_provider,
-                        dynamic_api_key,
-                        api_base,
-                    ) = litellm.get_llm_provider(
+                    model, _, _, _ = litellm.get_llm_provider(
                         model=model,
                         custom_llm_provider=kwargs.get("custom_llm_provider", None),
                         api_base=kwargs.get("api_base", None),
                         api_key=kwargs.get("api_key", None),
                     )
-                    print_verbose(
-                        f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
-                    )
-                    logging_obj.update_environment_variables(
+                    self._update_litellm_logging_obj_environment(
+                        logging_obj=logging_obj,
                         model=model,
-                        user=kwargs.get("user", None),
-                        optional_params={},
-                        litellm_params={
-                            "logger_fn": kwargs.get("logger_fn", None),
-                            "acompletion": True,
-                            "metadata": kwargs.get("metadata", {}),
-                            "model_info": kwargs.get("model_info", {}),
-                            "proxy_server_request": kwargs.get(
-                                "proxy_server_request", None
-                            ),
-                            "preset_cache_key": kwargs.get("preset_cache_key", None),
-                            "stream_response": kwargs.get("stream_response", {}),
-                            "api_base": kwargs.get("api_base", ""),
-                        },
-                        input=kwargs.get("messages", ""),
-                        api_key=kwargs.get("api_key", None),
-                        original_response=str(cached_result),
-                        additional_args=None,
-                        stream=kwargs.get("stream", False),
+                        kwargs=kwargs,
+                        cached_result=cached_result,
+                        is_async=True,
                     )
 
                     call_type = original_function.__name__
 
                     cached_result = self._convert_cached_result_to_model_response(
@@ -184,15 +170,13 @@ class LLMCachingHandler:
                     )
                     if kwargs.get("stream", False) is False:
                         # LOG SUCCESS
-                        asyncio.create_task(
-                            logging_obj.async_success_handler(
-                                cached_result, start_time, end_time, cache_hit
-                            )
+                        self._async_log_cache_hit_on_callbacks(
+                            logging_obj=logging_obj,
+                            cached_result=cached_result,
+                            start_time=start_time,
+                            end_time=end_time,
+                            cache_hit=cache_hit,
                         )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(cached_result, start_time, end_time, cache_hit),
-                        ).start()
                     cache_key = kwargs.get("preset_cache_key", None)
                     if (
                         isinstance(cached_result, BaseModel)
@@ -209,101 +193,261 @@ class LLMCachingHandler:
                         litellm.cache.cache, S3Cache
                     )  # s3 doesn't support bulk writing. Exclude.
                 ):
-                    remaining_list = []
-                    non_null_list = []
-                    for idx, cr in enumerate(cached_result):
-                        if cr is None:
-                            remaining_list.append(kwargs["input"][idx])
-                        else:
-                            non_null_list.append((idx, cr))
-                    original_kwargs_input = kwargs["input"]
-                    kwargs["input"] = remaining_list
-                    if len(non_null_list) > 0:
-                        print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
-                        final_embedding_cached_response = EmbeddingResponse(
-                            model=kwargs.get("model"),
-                            data=[None] * len(original_kwargs_input),
-                        )
-                        final_embedding_cached_response._hidden_params["cache_hit"] = (
-                            True
-                        )
-
-                        for val in non_null_list:
-                            idx, cr = val  # (idx, cr) tuple
-                            if cr is not None:
-                                final_embedding_cached_response.data[idx] = Embedding(
-                                    embedding=cr["embedding"],
-                                    index=idx,
-                                    object="embedding",
-                                )
-                    if len(remaining_list) == 0:
-                        # LOG SUCCESS
-                        cache_hit = True
-                        end_time = datetime.datetime.now()
-                        (
-                            model,
-                            custom_llm_provider,
-                            dynamic_api_key,
-                            api_base,
-                        ) = litellm.get_llm_provider(
-                            model=model,
-                            custom_llm_provider=kwargs.get("custom_llm_provider", None),
-                            api_base=kwargs.get("api_base", None),
-                            api_key=kwargs.get("api_key", None),
-                        )
-                        print_verbose(
-                            f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
-                        )
-                        logging_obj.update_environment_variables(
-                            model=model,
-                            user=kwargs.get("user", None),
-                            optional_params={},
-                            litellm_params={
-                                "logger_fn": kwargs.get("logger_fn", None),
-                                "acompletion": True,
-                                "metadata": kwargs.get("metadata", {}),
-                                "model_info": kwargs.get("model_info", {}),
-                                "proxy_server_request": kwargs.get(
-                                    "proxy_server_request", None
-                                ),
-                                "preset_cache_key": kwargs.get(
-                                    "preset_cache_key", None
-                                ),
-                                "stream_response": kwargs.get("stream_response", {}),
-                                "api_base": "",
-                            },
-                            input=kwargs.get("messages", ""),
-                            api_key=kwargs.get("api_key", None),
-                            original_response=str(final_embedding_cached_response),
-                            additional_args=None,
-                            stream=kwargs.get("stream", False),
-                        )
-                        asyncio.create_task(
-                            logging_obj.async_success_handler(
-                                final_embedding_cached_response,
-                                start_time,
-                                end_time,
-                                cache_hit,
-                            )
-                        )
-                        threading.Thread(
-                            target=logging_obj.success_handler,
-                            args=(
-                                final_embedding_cached_response,
-                                start_time,
-                                end_time,
-                                cache_hit,
-                            ),
-                        ).start()
-                        return CachingHandlerResponse(
-                            final_embedding_cached_response=final_embedding_cached_response,
-                            embedding_all_elements_cache_hit=True,
-                        )
+                    (
+                        final_embedding_cached_response,
+                        embedding_all_elements_cache_hit,
+                    ) = self._process_async_embedding_cached_response(
+                        final_embedding_cached_response=final_embedding_cached_response,
+                        cached_result=cached_result,
+                        kwargs=kwargs,
+                        logging_obj=logging_obj,
+                        start_time=start_time,
+                        model=model,
+                    )
+                    return CachingHandlerResponse(
+                        final_embedding_cached_response=final_embedding_cached_response,
+                        embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
+                    )
         return CachingHandlerResponse(
             cached_result=cached_result,
             final_embedding_cached_response=final_embedding_cached_response,
         )
 
+    def _sync_get_cache(
+        self,
+        model: str,
+        original_function: Callable,
+        logging_obj: LiteLLMLoggingObj,
+        start_time: datetime.datetime,
+        call_type: str,
+        kwargs: Dict[str, Any],
+        args: Optional[Tuple[Any, ...]] = None,
+    ) -> CachingHandlerResponse:
+        from litellm.utils import CustomStreamWrapper
+
+        args = args or ()
+        cached_result: Optional[Any] = None
+        if litellm.cache is not None and self._is_call_type_supported_by_cache(
+            original_function=original_function
+        ):
+            print_verbose("Checking Cache")
+            preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
+            kwargs["preset_cache_key"] = (
+                preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+            )
+            cached_result = litellm.cache.get_cache(*args, **kwargs)
+
+            if cached_result is not None:
+                if "detail" in cached_result:
+                    # implies an error occurred
+                    pass
+                else:
+                    call_type = original_function.__name__
+
+                    cached_result = self._convert_cached_result_to_model_response(
+                        cached_result=cached_result,
+                        call_type=call_type,
+                        kwargs=kwargs,
+                        logging_obj=logging_obj,
+                        model=model,
+                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
+                        args=args,
+                    )
+
+                    # LOG SUCCESS
+                    cache_hit = True
+                    end_time = datetime.datetime.now()
+                    (
+                        model,
+                        custom_llm_provider,
+                        dynamic_api_key,
+                        api_base,
+                    ) = litellm.get_llm_provider(
+                        model=model or "",
+                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
+                        api_base=kwargs.get("api_base", None),
+                        api_key=kwargs.get("api_key", None),
+                    )
+                    self._update_litellm_logging_obj_environment(
+                        logging_obj=logging_obj,
+                        model=model,
+                        kwargs=kwargs,
+                        cached_result=cached_result,
+                        is_async=False,
+                    )
+
+                    threading.Thread(
+                        target=logging_obj.success_handler,
+                        args=(cached_result, start_time, end_time, cache_hit),
+                    ).start()
+                    cache_key = kwargs.get("preset_cache_key", None)
+                    if (
+                        isinstance(cached_result, BaseModel)
+                        or isinstance(cached_result, CustomStreamWrapper)
+                    ) and hasattr(cached_result, "_hidden_params"):
+                        cached_result._hidden_params["cache_key"] = cache_key  # type: ignore
+                    return CachingHandlerResponse(cached_result=cached_result)
+        return CachingHandlerResponse(cached_result=cached_result)
+
+    def _process_async_embedding_cached_response(
+        self,
+        final_embedding_cached_response: Optional[EmbeddingResponse],
+        cached_result: List[Optional[Dict[str, Any]]],
+        kwargs: Dict[str, Any],
+        logging_obj: LiteLLMLoggingObj,
+        start_time: datetime.datetime,
+        model: str,
+    ) -> Tuple[Optional[EmbeddingResponse], bool]:
+        """
+        Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
+
+        For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others
+        This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
+
+        Args:
+            final_embedding_cached_response: Optional[EmbeddingResponse]:
+            cached_result: List[Optional[Dict[str, Any]]]:
+            kwargs: Dict[str, Any]:
+            logging_obj: LiteLLMLoggingObj:
+            start_time: datetime.datetime:
+            model: str:
+
+        Returns:
+            Tuple[Optional[EmbeddingResponse], bool]:
+            Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
+        """
+        embedding_all_elements_cache_hit: bool = False
+        remaining_list = []
+        non_null_list = []
+        for idx, cr in enumerate(cached_result):
+            if cr is None:
+                remaining_list.append(kwargs["input"][idx])
+            else:
+                non_null_list.append((idx, cr))
+        original_kwargs_input = kwargs["input"]
+        kwargs["input"] = remaining_list
+        if len(non_null_list) > 0:
+            print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
+            final_embedding_cached_response = EmbeddingResponse(
+                model=kwargs.get("model"),
+                data=[None] * len(original_kwargs_input),
+            )
+            final_embedding_cached_response._hidden_params["cache_hit"] = True
+
+            for val in non_null_list:
+                idx, cr = val  # (idx, cr) tuple
+                if cr is not None:
+                    final_embedding_cached_response.data[idx] = Embedding(
+                        embedding=cr["embedding"],
+                        index=idx,
+                        object="embedding",
+                    )
+        if len(remaining_list) == 0:
+            # LOG SUCCESS
+            cache_hit = True
+            embedding_all_elements_cache_hit = True
+            end_time = datetime.datetime.now()
+            (
+                model,
+                custom_llm_provider,
+                dynamic_api_key,
+                api_base,
+            ) = litellm.get_llm_provider(
+                model=model,
+                custom_llm_provider=kwargs.get("custom_llm_provider", None),
+                api_base=kwargs.get("api_base", None),
+                api_key=kwargs.get("api_key", None),
+            )
+
+            self._update_litellm_logging_obj_environment(
+                logging_obj=logging_obj,
+                model=model,
+                kwargs=kwargs,
+                cached_result=final_embedding_cached_response,
+                is_async=True,
+                is_embedding=True,
+            )
+            self._async_log_cache_hit_on_callbacks(
+                logging_obj=logging_obj,
+                cached_result=final_embedding_cached_response,
+                start_time=start_time,
+                end_time=end_time,
+                cache_hit=cache_hit,
+            )
+            return final_embedding_cached_response, embedding_all_elements_cache_hit
+        return final_embedding_cached_response, embedding_all_elements_cache_hit
+
+    def _combine_cached_embedding_response_with_api_result(
+        self,
+        _caching_handler_response: CachingHandlerResponse,
+        embedding_response: EmbeddingResponse,
+        start_time: datetime.datetime,
+        end_time: datetime.datetime,
+    ) -> EmbeddingResponse:
+        """
+        Combines the cached embedding response with the API EmbeddingResponse
+
+        For caching there can be a cache hit for some of the inputs in the list and a cache miss for others
+        This function combines the cached embedding response with the API EmbeddingResponse
+
+        Args:
+            caching_handler_response: CachingHandlerResponse:
+            embedding_response: EmbeddingResponse:
+
+        Returns:
+            EmbeddingResponse:
+        """
+        if _caching_handler_response.final_embedding_cached_response is None:
+            return embedding_response
+
+        idx = 0
+        final_data_list = []
+        for item in _caching_handler_response.final_embedding_cached_response.data:
+            if item is None and embedding_response.data is not None:
+                final_data_list.append(embedding_response.data[idx])
+                idx += 1
+            else:
+                final_data_list.append(item)
+
+        _caching_handler_response.final_embedding_cached_response.data = final_data_list
+        _caching_handler_response.final_embedding_cached_response._hidden_params[
+            "cache_hit"
+        ] = True
+        _caching_handler_response.final_embedding_cached_response._response_ms = (
+            end_time - start_time
+        ).total_seconds() * 1000
+        return _caching_handler_response.final_embedding_cached_response
+
+    def _async_log_cache_hit_on_callbacks(
+        self,
+        logging_obj: LiteLLMLoggingObj,
+        cached_result: Any,
+        start_time: datetime.datetime,
+        end_time: datetime.datetime,
+        cache_hit: bool,
+    ):
+        """
+        Helper function to log the success of a cached result on callbacks
+
+        Args:
+            logging_obj (LiteLLMLoggingObj): The logging object.
+            cached_result: The cached result.
+            start_time (datetime): The start time of the operation.
+            end_time (datetime): The end time of the operation.
+            cache_hit (bool): Whether it was a cache hit.
+        """
+        asyncio.create_task(
+            logging_obj.async_success_handler(
+                cached_result, start_time, end_time, cache_hit
+            )
+        )
+        threading.Thread(
+            target=logging_obj.success_handler,
+            args=(cached_result, start_time, end_time, cache_hit),
+        ).start()
+
     async def _retrieve_from_cache(
         self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
     ) -> Optional[Any]:
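The two embedding helpers added above implement a positional merge: inputs with a cache hit keep their original slot, and fresh API results fill the None gaps in order. A standalone toy sketch of that merge rule (illustrative only, not litellm code):

from typing import Any, List, Optional

def merge_cached_and_fresh(
    cached: List[Optional[Any]], fresh: List[Any]
) -> List[Any]:
    merged: List[Any] = []
    fresh_idx = 0
    for item in cached:
        if item is None:
            # cache miss for this input -> take the next API result
            merged.append(fresh[fresh_idx])
            fresh_idx += 1
        else:
            # cache hit -> keep the cached embedding in its original position
            merged.append(item)
    return merged

# cache hit for inputs 0 and 2, miss for input 1
print(merge_cached_and_fresh(["emb_a", None, "emb_c"], ["emb_b_from_api"]))
# -> ['emb_a', 'emb_b_from_api', 'emb_c']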
@@ -385,57 +529,60 @@ class LLMCachingHandler:
         from litellm.utils import (
             CustomStreamWrapper,
             convert_to_model_response_object,
+            convert_to_streaming_response,
             convert_to_streaming_response_async,
         )
 
-        if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
+        if (
+            call_type == CallTypes.acompletion.value
+            or call_type == CallTypes.completion.value
+        ) and isinstance(cached_result, dict):
             if kwargs.get("stream", False) is True:
-                cached_result = convert_to_streaming_response_async(
-                    response_object=cached_result,
-                )
-                cached_result = CustomStreamWrapper(
-                    completion_stream=cached_result,
-                    model=model,
-                    custom_llm_provider="cached_response",
+                cached_result = self._convert_cached_stream_response(
+                    cached_result=cached_result,
+                    call_type=call_type,
                     logging_obj=logging_obj,
+                    model=model,
                 )
             else:
                 cached_result = convert_to_model_response_object(
                     response_object=cached_result,
                     model_response_object=ModelResponse(),
                 )
-        if call_type == CallTypes.atext_completion.value and isinstance(
-            cached_result, dict
-        ):
+        if (
+            call_type == CallTypes.atext_completion.value
+            or call_type == CallTypes.text_completion.value
+        ) and isinstance(cached_result, dict):
             if kwargs.get("stream", False) is True:
-                cached_result = convert_to_streaming_response_async(
-                    response_object=cached_result,
-                )
-                cached_result = CustomStreamWrapper(
-                    completion_stream=cached_result,
-                    model=model,
-                    custom_llm_provider="cached_response",
+                cached_result = self._convert_cached_stream_response(
+                    cached_result=cached_result,
+                    call_type=call_type,
                     logging_obj=logging_obj,
+                    model=model,
                 )
             else:
                 cached_result = TextCompletionResponse(**cached_result)
-        elif call_type == CallTypes.aembedding.value and isinstance(
-            cached_result, dict
-        ):
+        elif (
+            call_type == CallTypes.aembedding.value
+            or call_type == CallTypes.embedding.value
+        ) and isinstance(cached_result, dict):
             cached_result = convert_to_model_response_object(
                 response_object=cached_result,
                 model_response_object=EmbeddingResponse(),
                 response_type="embedding",
             )
-        elif call_type == CallTypes.arerank.value and isinstance(cached_result, dict):
+        elif (
+            call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
+        ) and isinstance(cached_result, dict):
             cached_result = convert_to_model_response_object(
                 response_object=cached_result,
                 model_response_object=None,
                 response_type="rerank",
             )
-        elif call_type == CallTypes.atranscription.value and isinstance(
-            cached_result, dict
-        ):
+        elif (
+            call_type == CallTypes.atranscription.value
+            or call_type == CallTypes.transcription.value
+        ) and isinstance(cached_result, dict):
             hidden_params = {
                 "model": "whisper-1",
                 "custom_llm_provider": custom_llm_provider,
@@ -449,6 +596,38 @@ class LLMCachingHandler:
             )
         return cached_result
 
+    def _convert_cached_stream_response(
+        self,
+        cached_result: Any,
+        call_type: str,
+        logging_obj: LiteLLMLoggingObj,
+        model: str,
+    ) -> CustomStreamWrapper:
+        from litellm.utils import (
+            CustomStreamWrapper,
+            convert_to_streaming_response,
+            convert_to_streaming_response_async,
+        )
+
+        _stream_cached_result: Union[AsyncGenerator, Generator]
+        if (
+            call_type == CallTypes.acompletion.value
+            or call_type == CallTypes.atext_completion.value
+        ):
+            _stream_cached_result = convert_to_streaming_response_async(
+                response_object=cached_result,
+            )
+        else:
+            _stream_cached_result = convert_to_streaming_response(
+                response_object=cached_result,
+            )
+        return CustomStreamWrapper(
+            completion_stream=_stream_cached_result,
+            model=model,
+            custom_llm_provider="cached_response",
+            logging_obj=logging_obj,
+        )
+
     async def _async_set_cache(
         self,
         result: Any,
|
@ -545,6 +724,28 @@ class LLMCachingHandler:
|
||||||
and (kwargs.get("cache", {}).get("no-store", False) is not True)
|
and (kwargs.get("cache", {}).get("no-store", False) is not True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _is_call_type_supported_by_cache(
|
||||||
|
self,
|
||||||
|
original_function: Callable,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Helper function to determine if the call type is supported by the cache.
|
||||||
|
|
||||||
|
call types are acompletion, aembedding, atext_completion, atranscription, arerank
|
||||||
|
|
||||||
|
Defined on `litellm.types.utils.CallTypes`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the call type is supported by the cache, False otherwise.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
litellm.cache is not None
|
||||||
|
and litellm.cache.supported_call_types is not None
|
||||||
|
and str(original_function.__name__) in litellm.cache.supported_call_types
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
|
async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
|
||||||
"""
|
"""
|
||||||
Internal method to add the streaming response to the cache
|
Internal method to add the streaming response to the cache
|
||||||
|
@@ -594,3 +795,53 @@ class LLMCachingHandler:
             result=complete_streaming_response,
             kwargs=self.request_kwargs,
         )
+
+    def _update_litellm_logging_obj_environment(
+        self,
+        logging_obj: LiteLLMLoggingObj,
+        model: str,
+        kwargs: Dict[str, Any],
+        cached_result: Any,
+        is_async: bool,
+        is_embedding: bool = False,
+    ):
+        """
+        Helper function to update the LiteLLMLoggingObj environment variables.
+
+        Args:
+            logging_obj (LiteLLMLoggingObj): The logging object to update.
+            model (str): The model being used.
+            kwargs (Dict[str, Any]): The keyword arguments from the original function call.
+            cached_result (Any): The cached result to log.
+            is_async (bool): Whether the call is asynchronous or not.
+            is_embedding (bool): Whether the call is for embeddings or not.
+
+        Returns:
+            None
+        """
+        litellm_params = {
+            "logger_fn": kwargs.get("logger_fn", None),
+            "acompletion": is_async,
+            "api_base": kwargs.get("api_base", ""),
+            "metadata": kwargs.get("metadata", {}),
+            "model_info": kwargs.get("model_info", {}),
+            "proxy_server_request": kwargs.get("proxy_server_request", None),
+            "preset_cache_key": kwargs.get("preset_cache_key", None),
+            "stream_response": kwargs.get("stream_response", {}),
+        }
+
+        logging_obj.update_environment_variables(
+            model=model,
+            user=kwargs.get("user", None),
+            optional_params={},
+            litellm_params=litellm_params,
+            input=(
+                kwargs.get("messages", "")
+                if not is_embedding
+                else kwargs.get("input", "")
+            ),
+            api_key=kwargs.get("api_key", None),
+            original_response=str(cached_result),
+            additional_args=None,
+            stream=kwargs.get("stream", False),
+        )
litellm/utils.py (149 changed lines)
@@ -773,6 +773,8 @@ def client(original_function):
         call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
+
+        model: Optional[str] = None
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except Exception:
@@ -844,116 +846,20 @@ def client(original_function):
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose("INSIDE CHECKING CACHE")
-                if (
-                    litellm.cache is not None
-                    and litellm.cache.supported_call_types is not None
-                    and str(original_function.__name__)
-                    in litellm.cache.supported_call_types
-                ):
-                    print_verbose("Checking Cache")
-                    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                    kwargs["preset_cache_key"] = (
-                        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                    )
-                    cached_result = litellm.cache.get_cache(*args, **kwargs)
-                    if cached_result is not None:
-                        if "detail" in cached_result:
-                            # implies an error occurred
-                            pass
-                        else:
-                            call_type = original_function.__name__
-                            print_verbose(
-                                f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}"
-                            )
-                            if call_type == CallTypes.completion.value and isinstance(
-                                cached_result, dict
-                            ):
-                                cached_result = convert_to_model_response_object(
-                                    response_object=cached_result,
-                                    model_response_object=ModelResponse(),
-                                    stream=kwargs.get("stream", False),
-                                )
-
-                                if kwargs.get("stream", False) is True:
-                                    cached_result = CustomStreamWrapper(
-                                        completion_stream=cached_result,
-                                        model=model,
-                                        custom_llm_provider="cached_response",
-                                        logging_obj=logging_obj,
-                                    )
-                            elif call_type == CallTypes.embedding.value and isinstance(
-                                cached_result, dict
-                            ):
-                                cached_result = convert_to_model_response_object(
-                                    response_object=cached_result,
-                                    response_type="embedding",
-                                )
-                            elif call_type == CallTypes.rerank.value and isinstance(
-                                cached_result, dict
-                            ):
-                                cached_result = convert_to_model_response_object(
-                                    response_object=cached_result,
-                                    response_type="rerank",
-                                )
-                            # LOG SUCCESS
-                            cache_hit = True
-                            end_time = datetime.datetime.now()
-                            (
-                                model,
-                                custom_llm_provider,
-                                dynamic_api_key,
-                                api_base,
-                            ) = litellm.get_llm_provider(
-                                model=model or "",
-                                custom_llm_provider=kwargs.get(
-                                    "custom_llm_provider", None
-                                ),
-                                api_base=kwargs.get("api_base", None),
-                                api_key=kwargs.get("api_key", None),
-                            )
-                            print_verbose(
-                                f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
-                            )
-                            logging_obj.update_environment_variables(
-                                model=model,
-                                user=kwargs.get("user", None),
-                                optional_params={},
-                                litellm_params={
-                                    "logger_fn": kwargs.get("logger_fn", None),
-                                    "acompletion": False,
-                                    "metadata": kwargs.get("metadata", {}),
-                                    "model_info": kwargs.get("model_info", {}),
-                                    "proxy_server_request": kwargs.get(
-                                        "proxy_server_request", None
-                                    ),
-                                    "preset_cache_key": kwargs.get(
-                                        "preset_cache_key", None
-                                    ),
-                                    "stream_response": kwargs.get(
-                                        "stream_response", {}
-                                    ),
-                                },
-                                input=kwargs.get("messages", ""),
-                                api_key=kwargs.get("api_key", None),
-                                original_response=str(cached_result),
-                                additional_args=None,
-                                stream=kwargs.get("stream", False),
-                            )
-                            threading.Thread(
-                                target=logging_obj.success_handler,
-                                args=(cached_result, start_time, end_time, cache_hit),
-                            ).start()
-                            cache_key = kwargs.get("preset_cache_key", None)
-                            if (
-                                isinstance(cached_result, BaseModel)
-                                or isinstance(cached_result, CustomStreamWrapper)
-                            ) and hasattr(cached_result, "_hidden_params"):
-                                cached_result._hidden_params["cache_key"] = cache_key  # type: ignore
-                                return cached_result
-                    else:
-                        print_verbose(
-                            "Cache Miss! on key - {}".format(preset_cache_key)
-                        )
+                caching_handler_response: CachingHandlerResponse = (
+                    _llm_caching_handler._sync_get_cache(
+                        model=model or "",
+                        original_function=original_function,
+                        logging_obj=logging_obj,
+                        start_time=start_time,
+                        call_type=call_type,
+                        kwargs=kwargs,
+                        args=args,
+                    )
+                )
+                if caching_handler_response.cached_result is not None:
+                    return caching_handler_response.cached_result
                # CHECK MAX TOKENS
                if (
                    kwargs.get("max_tokens", None) is not None
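The wrapper change above reduces the sync cache path to "ask the caching handler, return early on a hit". A simplified, generic sketch of that short-circuit pattern; the names below are illustrative, not litellm's:

import functools
from typing import Any, Callable, Dict, Optional

class ToyCacheHandler:
    """Stand-in for the caching handler: a plain in-memory key/value store."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def get(self, key: str) -> Optional[Any]:
        return self._store.get(key)

    def set(self, key: str, value: Any) -> None:
        self._store[key] = value

def with_cache(handler: ToyCacheHandler) -> Callable:
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            key = f"{fn.__name__}:{args}:{sorted(kwargs.items())}"
            cached = handler.get(key)      # analogous to _sync_get_cache()
            if cached is not None:
                return cached              # early return on cache hit
            result = fn(*args, **kwargs)   # cache miss -> call through
            handler.set(key, result)       # analogous to _sync_set_cache()
            return result
        return wrapper
    return decorator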
@@ -1245,30 +1151,13 @@ def client(original_function):
                    isinstance(result, EmbeddingResponse)
                    and _caching_handler_response.final_embedding_cached_response
                    is not None
-                    and _caching_handler_response.final_embedding_cached_response.data
-                    is not None
                ):
-                    idx = 0
-                    final_data_list = []
-                    for (
-                        item
-                    ) in _caching_handler_response.final_embedding_cached_response.data:
-                        if item is None and result.data is not None:
-                            final_data_list.append(result.data[idx])
-                            idx += 1
-                        else:
-                            final_data_list.append(item)
-
-                    _caching_handler_response.final_embedding_cached_response.data = (
-                        final_data_list
-                    )
-                    _caching_handler_response.final_embedding_cached_response._hidden_params[
-                        "cache_hit"
-                    ] = True
-                    _caching_handler_response.final_embedding_cached_response._response_ms = (
-                        end_time - start_time
-                    ).total_seconds() * 1000
-                    return _caching_handler_response.final_embedding_cached_response
+                    return _llm_caching_handler._combine_cached_embedding_response_with_api_result(
+                        _caching_handler_response=_caching_handler_response,
+                        embedding_response=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                    )
 
                return result
            except Exception as e:
@@ -1067,7 +1067,7 @@ async def test_redis_cache_acompletion_stream_bedrock():
            response_1_content += chunk.choices[0].delta.content or ""
        print(response_1_content)
 
-        time.sleep(0.5)
+        await asyncio.sleep(1)
        print("\n\n Response 1 content: ", response_1_content, "\n\n")
 
        response2 = await litellm.acompletion(
@@ -1082,8 +1082,8 @@ async def test_redis_cache_acompletion_stream_bedrock():
            response_2_content += chunk.choices[0].delta.content or ""
        print(response_2_content)
 
-        print("\nresponse 1", response_1_content)
-        print("\nresponse 2", response_2_content)
+        print("\nfinal response 1", response_1_content)
+        print("\nfinal response 2", response_2_content)
        assert (
            response_1_content == response_2_content
        ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"