(refactor) sync caching - use LLMCachingHandler class for get_cache (#6249)

* caching - use _sync_set_cache

* add sync _sync_add_streaming_response_to_cache

* use caching class for cache storage

* fix use _sync_get_cache

* fix circular import

* use _update_litellm_logging_obj_environment

* use one helper for _process_async_embedding_cached_response

* fix _is_call_type_supported_by_cache

* fix checking cache

* fix sync get cache

* fix use _combine_cached_embedding_response_with_api_result

* fix _update_litellm_logging_obj_environment

* adjust test_redis_cache_acompletion_stream_bedrock
This commit is contained in:
Ishaan Jaff 2024-10-16 12:33:49 +05:30 committed by GitHub
parent 183bd5d873
commit 97ba4eea7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 434 additions and 294 deletions

View file

@ -773,6 +773,8 @@ def client(original_function):
call_type = original_function.__name__
if "litellm_call_id" not in kwargs:
kwargs["litellm_call_id"] = str(uuid.uuid4())
model: Optional[str] = None
try:
model = args[0] if len(args) > 0 else kwargs["model"]
except Exception:
@ -844,116 +846,20 @@ def client(original_function):
): # allow users to control returning cached responses from the completion function
# checking cache
print_verbose("INSIDE CHECKING CACHE")
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
):
print_verbose("Checking Cache")
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
caching_handler_response: CachingHandlerResponse = (
_llm_caching_handler._sync_get_cache(
model=model or "",
original_function=original_function,
logging_obj=logging_obj,
start_time=start_time,
call_type=call_type,
kwargs=kwargs,
args=args,
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None:
if "detail" in cached_result:
# implies an error occurred
pass
else:
call_type = original_function.__name__
print_verbose(
f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}"
)
if call_type == CallTypes.completion.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
stream=kwargs.get("stream", False),
)
)
if caching_handler_response.cached_result is not None:
return caching_handler_response.cached_result
if kwargs.get("stream", False) is True:
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
elif call_type == CallTypes.embedding.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
response_type="embedding",
)
elif call_type == CallTypes.rerank.value and isinstance(
cached_result, dict
):
cached_result = convert_to_model_response_object(
response_object=cached_result,
response_type="rerank",
)
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model or "",
custom_llm_provider=kwargs.get(
"custom_llm_provider", None
),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
print_verbose(
f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
)
logging_obj.update_environment_variables(
model=model,
user=kwargs.get("user", None),
optional_params={},
litellm_params={
"logger_fn": kwargs.get("logger_fn", None),
"acompletion": False,
"metadata": kwargs.get("metadata", {}),
"model_info": kwargs.get("model_info", {}),
"proxy_server_request": kwargs.get(
"proxy_server_request", None
),
"preset_cache_key": kwargs.get(
"preset_cache_key", None
),
"stream_response": kwargs.get(
"stream_response", {}
),
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
original_response=str(cached_result),
additional_args=None,
stream=kwargs.get("stream", False),
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
cache_key = kwargs.get("preset_cache_key", None)
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return cached_result
else:
print_verbose(
"Cache Miss! on key - {}".format(preset_cache_key)
)
# CHECK MAX TOKENS
if (
kwargs.get("max_tokens", None) is not None
@ -1245,30 +1151,13 @@ def client(original_function):
isinstance(result, EmbeddingResponse)
and _caching_handler_response.final_embedding_cached_response
is not None
and _caching_handler_response.final_embedding_cached_response.data
is not None
):
idx = 0
final_data_list = []
for (
item
) in _caching_handler_response.final_embedding_cached_response.data:
if item is None and result.data is not None:
final_data_list.append(result.data[idx])
idx += 1
else:
final_data_list.append(item)
_caching_handler_response.final_embedding_cached_response.data = (
final_data_list
return _llm_caching_handler._combine_cached_embedding_response_with_api_result(
_caching_handler_response=_caching_handler_response,
embedding_response=result,
start_time=start_time,
end_time=end_time,
)
_caching_handler_response.final_embedding_cached_response._hidden_params[
"cache_hit"
] = True
_caching_handler_response.final_embedding_cached_response._response_ms = (
end_time - start_time
).total_seconds() * 1000
return _caching_handler_response.final_embedding_cached_response
return result
except Exception as e: