(refactor) get_cache_key to be under 100 LOC function (#6327)

* refactor - use helpers for namespace and hashing (a rough sketch appears just before the diff below)

* use openai to get the relevant supported params

* use helpers for getting cache key

* fix test caching

* use get/set helpers for preset cache keys

* make get_cache_key under 100 LOC

* fix _get_model_param_value

* fix _get_caching_group

* fix linting error

* add unit testing for get cache key

* test_generate_streaming_content
Ishaan Jaff 2024-10-19 15:21:11 +05:30 committed by GitHub
parent 4cbdad9fc5
commit 979e8ea526
5 changed files with 477 additions and 124 deletions
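
A rough sketch of the refactor's shape, before the diff: the commit message describes splitting key construction into small helpers for hashing and namespacing. The helper names below (_get_hashed_cache_key, _add_namespace_to_cache_key) and the redis_namespace metadata field are assumptions for illustration, not necessarily what the PR implements:

    import hashlib
    from typing import Optional

    class Cache:
        def __init__(self, namespace: Optional[str] = None):
            self.namespace = namespace

        def _get_hashed_cache_key(self, raw_key: str) -> str:
            # Hash the serialized params so long prompts still map to a fixed-size key.
            return hashlib.sha256(raw_key.encode()).hexdigest()

        def _add_namespace_to_cache_key(self, hashed_key: str, **kwargs) -> str:
            # Prefer a per-call namespace from metadata, else the cache-wide one.
            namespace = (kwargs.get("metadata") or {}).get("redis_namespace") or self.namespace
            return f"{namespace}:{hashed_key}" if namespace else hashed_key

        def get_cache_key(self, *args, **kwargs) -> str:
            # Serialize only the params that affect the completion, then delegate
            # hashing and namespacing to the helpers above.
            relevant = ("model", "messages", "temperature", "max_tokens")
            raw_key = ",".join(f"{k}:{kwargs[k]}" for k in relevant if k in kwargs)
            return self._add_namespace_to_cache_key(
                self._get_hashed_cache_key(raw_key), **kwargs
            )

With this kind of split, each helper stays small and get_cache_key itself can stay under the 100 LOC target named in the commit title.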


@@ -182,7 +182,9 @@ class LLMCachingHandler:
     end_time=end_time,
     cache_hit=cache_hit,
 )
-cache_key = kwargs.get("preset_cache_key", None)
+cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+    **kwargs
+)
 if (
     isinstance(cached_result, BaseModel)
     or isinstance(cached_result, CustomStreamWrapper)
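
The hunk above replaces a raw kwargs.get("preset_cache_key") lookup with a dedicated getter on the cache object. A minimal sketch of the get/set pair the commit message refers to ("use get/set helpers for preset cache keys"); the setter name and the litellm_params storage location are assumptions:

    from typing import Optional

    class Cache:
        def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, kwargs: dict) -> None:
            # Hypothetical counterpart to the getter used in the diff: store the
            # computed key once so streaming calls and logging can reuse it
            # without recomputing get_cache_key().
            kwargs.setdefault("litellm_params", {})["preset_cache_key"] = preset_cache_key

        def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
            # Read back whatever was stored; returns None if no key was ever computed.
            return (kwargs.get("litellm_params") or {}).get("preset_cache_key", None)

Centralizing the lookup behind one method means every caller (the two similar hunks in this file, and the logging setup further down) resolves the key the same way.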
@@ -236,12 +238,7 @@ class LLMCachingHandler:
     original_function=original_function
 ):
     print_verbose("Checking Cache")
-    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-    kwargs["preset_cache_key"] = (
-        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-    )
     cached_result = litellm.cache.get_cache(*args, **kwargs)
     if cached_result is not None:
         if "detail" in cached_result:
             # implies an error occurred
@@ -285,7 +282,9 @@ class LLMCachingHandler:
     target=logging_obj.success_handler,
     args=(cached_result, start_time, end_time, cache_hit),
 ).start()
-cache_key = kwargs.get("preset_cache_key", None)
+cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+    **kwargs
+)
 if (
     isinstance(cached_result, BaseModel)
     or isinstance(cached_result, CustomStreamWrapper)
@@ -493,10 +492,6 @@ class LLMCachingHandler:
 if all(result is None for result in cached_result):
     cached_result = None
 else:
-    preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-    kwargs["preset_cache_key"] = (
-        preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-    )
     if litellm.cache._supports_async() is True:
         cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
     else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
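
Both the sync hunk (-236) and this async hunk drop the manual kwargs["preset_cache_key"] mutation before the lookup. One plausible reading, sketched below under that assumption (this is not the PR's actual code), is that get_cache_key itself now records the key, so callers only invoke get_cache / async_get_cache:

    class Cache:
        def __init__(self):
            self._store: dict = {}  # stand-in for the real Redis / in-memory backends

        def get_cache_key(self, *args, **kwargs) -> str:
            # In the refactor this is also where the key would be written back into
            # kwargs (e.g. via a set helper) for streaming and logging to reuse.
            return str(sorted(kwargs.items()))

        def get_cache(self, *args, **kwargs):
            # Call sites like the handlers in the diff no longer thread
            # preset_cache_key through kwargs; they just ask for the cached value.
            cache_key = self.get_cache_key(*args, **kwargs)
            return self._store.get(cache_key)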
@@ -842,10 +837,16 @@ class LLMCachingHandler:
     "metadata": kwargs.get("metadata", {}),
     "model_info": kwargs.get("model_info", {}),
     "proxy_server_request": kwargs.get("proxy_server_request", None),
-    "preset_cache_key": kwargs.get("preset_cache_key", None),
     "stream_response": kwargs.get("stream_response", {}),
 }
+if litellm.cache is not None:
+    litellm_params["preset_cache_key"] = (
+        litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
+    )
+else:
+    litellm_params["preset_cache_key"] = None
 logging_obj.update_environment_variables(
     model=model,
     user=kwargs.get("user", None),