feat(caching.py): enable caching on provider-specific optional params

Closes https://github.com/BerriAI/litellm/issues/5049
This commit is contained in:
Krrish Dholakia 2024-08-05 11:18:59 -07:00
parent 14d0ae6aa4
commit 8500f6d087
7 changed files with 172 additions and 74 deletions

View file

@ -23,6 +23,7 @@ import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from litellm.types.utils import all_litellm_params
def print_verbose(print_statement):
@ -1838,6 +1839,7 @@ class Cache:
"seed",
"tools",
"tool_choice",
"stream",
]
embedding_only_kwargs = [
"input",
@ -1851,9 +1853,9 @@ class Cache:
combined_kwargs = (
completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
)
for param in combined_kwargs:
# ignore litellm params here
if param in kwargs:
litellm_param_kwargs = all_litellm_params
for param in kwargs:
if param in combined_kwargs:
# check if param == model and model_group is passed in, then override model with model_group
if param == "model":
model_group = None
@ -1897,6 +1899,17 @@ class Cache:
continue # ignore None params
param_value = kwargs[param]
cache_key += f"{str(param)}: {str(param_value)}"
elif (
param not in litellm_param_kwargs
): # check if user passed in optional param - e.g. top_k
if (
litellm.enable_caching_on_optional_params is True
): # feature flagged for now
if kwargs[param] is None:
continue # ignore None params
param_value = kwargs[param]
cache_key += f"{str(param)}: {str(param_value)}"
print_verbose(f"\nCreated cache key: {cache_key}")
# Use hashlib to create a sha256 hash of the cache key
hash_object = hashlib.sha256(cache_key.encode())
@ -2101,9 +2114,7 @@ class Cache:
try:
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
)
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx]
cache_key, cached_data, kwargs = self._add_cache_logic(