feat(caching.py): enable caching on provider-specific optional params

Closes https://github.com/BerriAI/litellm/issues/5049

parent 14d0ae6aa4
commit 8500f6d087

7 changed files with 172 additions and 74 deletions
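The change is feature-flagged: provider-specific optional params (for example top_k, which is not a standard OpenAI completion param) are folded into the cache key only when litellm.enable_caching_on_optional_params is set. A minimal usage sketch, assuming the default in-memory cache; the model name and param values are illustrative:

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache()  # defaults to in-memory caching
litellm.enable_caching_on_optional_params = True  # flag introduced by this commit

msgs = [{"role": "user", "content": "hello"}]
# with the flag on, these two calls hash to different cache keys, so the
# second call is not served the first call's cached response
r1 = completion(model="claude-3-haiku-20240307", messages=msgs, top_k=1, caching=True)
r2 = completion(model="claude-3-haiku-20240307", messages=msgs, top_k=40, caching=True)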
@@ -23,6 +23,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
+from litellm.types.utils import all_litellm_params


 def print_verbose(print_statement):
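all_litellm_params, imported above, enumerates litellm-internal kwargs (for example "metadata" and "caching") so the key-building code below can tell provider params apart from litellm control params. A standalone sketch of that filtering idea; raw_kwargs and provider_kwargs are hypothetical names, not litellm API:

from litellm.types.utils import all_litellm_params

raw_kwargs = {"model": "gpt-3.5-turbo", "top_k": 5, "caching": True}
# keep only params that should influence the cache key, dropping
# litellm-internal control params
provider_kwargs = {k: v for k, v in raw_kwargs.items() if k not in all_litellm_params}
print(provider_kwargs)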
@@ -1838,6 +1839,7 @@ class Cache:
             "seed",
             "tools",
             "tool_choice",
+            "stream",
         ]
         embedding_only_kwargs = [
             "input",
@@ -1851,9 +1853,9 @@ class Cache:
         combined_kwargs = (
             completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
         )
-        for param in combined_kwargs:
-            # ignore litellm params here
-            if param in kwargs:
+        litellm_param_kwargs = all_litellm_params
+        for param in kwargs:
+            if param in combined_kwargs:
                 # check if param == model and model_group is passed in, then override model with model_group
                 if param == "model":
                     model_group = None
@@ -1897,6 +1899,17 @@ class Cache:
                     continue  # ignore None params
                 param_value = kwargs[param]
                 cache_key += f"{str(param)}: {str(param_value)}"
+            elif (
+                param not in litellm_param_kwargs
+            ):  # check if user passed in optional param - e.g. top_k
+                if (
+                    litellm.enable_caching_on_optional_params is True
+                ):  # feature flagged for now
+                    if kwargs[param] is None:
+                        continue  # ignore None params
+                    param_value = kwargs[param]
+                    cache_key += f"{str(param)}: {str(param_value)}"
+
         print_verbose(f"\nCreated cache key: {cache_key}")
         # Use hashlib to create a sha256 hash of the cache key
         hash_object = hashlib.sha256(cache_key.encode())
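Taken together, the two hunks above switch key construction from iterating a fixed allow-list to iterating everything the caller passed: known params are always hashed into the key, while unknown non-litellm params contribute only when the feature flag is on. A condensed standalone re-statement of that logic; the function and argument names are hypothetical, not litellm API:

import hashlib

def sketch_cache_key(kwargs, known_params, litellm_params, flag_enabled):
    cache_key = ""
    for param, value in kwargs.items():
        if value is None:
            continue  # ignore None params, as the diff does
        if param in known_params:
            cache_key += f"{param}: {value}"  # standard param, always keyed
        elif param not in litellm_params and flag_enabled:
            cache_key += f"{param}: {value}"  # optional param, feature-flagged
    return hashlib.sha256(cache_key.encode()).hexdigest()

# same prompt, different top_k -> different keys once the flag is enabled
print(sketch_cache_key({"model": "gpt-4", "top_k": 1}, {"model"}, set(), True))
print(sketch_cache_key({"model": "gpt-4", "top_k": 40}, {"model"}, set(), True))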
@@ -2101,9 +2114,7 @@ class Cache:
         try:
             cache_list = []
             for idx, i in enumerate(kwargs["input"]):
-                preset_cache_key = litellm.cache.get_cache_key(
-                    *args, **{**kwargs, "input": i}
-                )
+                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
                 kwargs["cache_key"] = preset_cache_key
                 embedding_response = result.data[idx]
                 cache_key, cached_data, kwargs = self._add_cache_logic(
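The last hunk also stops routing through the module-level litellm.cache singleton and calls self.get_cache_key directly, so the cache instance doing the storing is the one computing per-input keys. Each string in an embedding batch still gets its own key; a sketch of that per-input idea, with an illustrative model name:

import litellm
from litellm.caching import Cache

cache = Cache()
inputs = ["hello", "world"]
# one cache key per input string, so a later batch that overlaps this one
# can be served partially from cache
keys = [cache.get_cache_key(model="text-embedding-ada-002", input=i) for i in inputs]
print(keys)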