(refactor) caching - use _sync_set_cache (#6224)

* caching - use _sync_set_cache

* add sync _sync_add_streaming_response_to_cache

* use caching class for cache storage
Ishaan Jaff 2024-10-16 10:38:07 +05:30 committed by GitHub
parent b8d4973661
commit da6a7c3a55
3 changed files with 89 additions and 28 deletions
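At a high level, this change routes every synchronous cache write through `LLMCachingHandler` instead of having the `client` wrapper and the logging success callback each call `litellm.cache.add_cache` behind their own inline eligibility checks. A compressed, self-contained sketch of that shape, using toy stand-ins rather than the real litellm classes:

```python
from typing import Any, Callable, Dict, Optional


class ToyCache:
    """Stand-in for litellm.cache with only the pieces the handler needs."""

    def __init__(self) -> None:
        self.supported_call_types = ["completion"]
        self.store: Dict[str, Any] = {}

    def add_cache(self, result: Any, **kwargs: Any) -> None:
        self.store[str(kwargs.get("messages"))] = result


cache: Optional[ToyCache] = ToyCache()


class ToyCachingHandler:
    """One instance per request; owns the eligibility check and the write."""

    def __init__(self, original_function: Callable, request_kwargs: Dict[str, Any]):
        self.original_function = original_function
        self.request_kwargs = request_kwargs

    def _should_store_result_in_cache(self, kwargs: Dict[str, Any]) -> bool:
        return (
            cache is not None
            and self.original_function.__name__ in cache.supported_call_types
            and kwargs.get("cache", {}).get("no-store", False) is not True
        )

    def _sync_set_cache(self, result: Any, kwargs: Dict[str, Any]) -> None:
        if cache is None:
            return
        if self._should_store_result_in_cache(kwargs=kwargs):
            cache.add_cache(result, **kwargs)


def completion(**kwargs: Any) -> str:
    # The call site no longer spells out cache conditions; it hands the result
    # to the handler, mirroring the new _sync_set_cache call in this commit.
    handler = ToyCachingHandler(original_function=completion, request_kwargs=kwargs)
    result = "fake llm response"
    handler._sync_set_cache(result=result, kwargs=kwargs)
    return result


completion(messages=[{"role": "user", "content": "hi"}])
print(cache.store if cache else {})  # the toy cache now holds the response
```

This is the same division of labor the diff below introduces: call sites shrink to a single `_sync_set_cache` call, and `_should_store_result_in_cache` becomes the one copy of the eligibility rules.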


@@ -67,6 +67,7 @@ class LLMCachingHandler:
         start_time: datetime.datetime,
     ):
         self.async_streaming_chunks: List[ModelResponse] = []
+        self.sync_streaming_chunks: List[ModelResponse] = []
         self.request_kwargs = request_kwargs
         self.original_function = original_function
         self.start_time = start_time
@@ -470,12 +471,11 @@ class LLMCachingHandler:
             None
         """
         args = args or ()
+        if litellm.cache is None:
+            return
         # [OPTIONAL] ADD TO CACHE
-        if (
-            (litellm.cache is not None)
-            and litellm.cache.supported_call_types is not None
-            and (str(original_function.__name__) in litellm.cache.supported_call_types)
-            and (kwargs.get("cache", {}).get("no-store", False) is not True)
+        if self._should_store_result_in_cache(
+            original_function=original_function, kwargs=kwargs
         ):
             if (
                 isinstance(result, litellm.ModelResponse)
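The `no-store` check that moves into the helper here is the per-request opt-out: callers can keep one response out of the cache without disabling caching globally. A hedged usage sketch, assuming litellm's documented per-request cache controls and the `litellm.caching.Cache` import path; verify both against the litellm version you run:

```python
import litellm
from litellm.caching import Cache  # import path may differ across versions

litellm.cache = Cache()  # enable caching globally (in-memory by default)

# Assumption: cache={"no-store": True} is the documented per-request control
# that tells litellm not to write this response to the cache.
resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    cache={"no-store": True},
)
```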
@@ -509,6 +509,42 @@
                 litellm.cache.async_add_cache(result, *args, **kwargs)
             )

+    def _sync_set_cache(
+        self,
+        result: Any,
+        kwargs: Dict[str, Any],
+        args: Optional[Tuple[Any, ...]] = None,
+    ):
+        """
+        Sync internal method to add the result to the cache
+        """
+        if litellm.cache is None:
+            return
+
+        args = args or ()
+        if self._should_store_result_in_cache(
+            original_function=self.original_function, kwargs=kwargs
+        ):
+            litellm.cache.add_cache(result, *args, **kwargs)
+
+        return
+
+    def _should_store_result_in_cache(
+        self, original_function: Callable, kwargs: Dict[str, Any]
+    ) -> bool:
+        """
+        Helper function to determine if the result should be stored in the cache.
+
+        Returns:
+            bool: True if the result should be stored in the cache, False otherwise.
+        """
+        return (
+            (litellm.cache is not None)
+            and litellm.cache.supported_call_types is not None
+            and (str(original_function.__name__) in litellm.cache.supported_call_types)
+            and (kwargs.get("cache", {}).get("no-store", False) is not True)
+        )
+
     async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
         """
         Internal method to add the streaming response to the cache
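The new helper concentrates the three storage conditions (a cache is configured, the call type is cacheable, the caller did not send `no-store`) in one place. A standalone re-implementation of the same predicate with a few edge-case checks; illustrative only, not the litellm source:

```python
from typing import Any, Callable, Dict, Optional


def should_store_result_in_cache(
    cache: Optional[object],
    supported_call_types: Optional[list],
    original_function: Callable,
    kwargs: Dict[str, Any],
) -> bool:
    # Mirrors the three checks in the diff: cache exists, the call type is
    # cacheable, and the caller did not send cache={"no-store": True}.
    return (
        cache is not None
        and supported_call_types is not None
        and str(original_function.__name__) in supported_call_types
        and kwargs.get("cache", {}).get("no-store", False) is not True
    )


def completion(**kwargs): ...


assert should_store_result_in_cache(object(), ["completion"], completion, {}) is True
assert should_store_result_in_cache(None, ["completion"], completion, {}) is False
assert (
    should_store_result_in_cache(
        object(), ["completion"], completion, {"cache": {"no-store": True}}
    )
    is False
)
assert should_store_result_in_cache(object(), ["embedding"], completion, {}) is False
```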
@@ -536,3 +572,25 @@
                 original_function=self.original_function,
                 kwargs=self.request_kwargs,
             )
+
+    def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
+        """
+        Sync internal method to add the streaming response to the cache
+        """
+        complete_streaming_response: Optional[
+            Union[ModelResponse, TextCompletionResponse]
+        ] = _assemble_complete_response_from_streaming_chunks(
+            result=processed_chunk,
+            start_time=self.start_time,
+            end_time=datetime.datetime.now(),
+            request_kwargs=self.request_kwargs,
+            streaming_chunks=self.sync_streaming_chunks,
+            is_async=False,
+        )
+
+        # if a complete_streaming_response is assembled, add it to the cache
+        if complete_streaming_response is not None:
+            self._sync_set_cache(
+                result=complete_streaming_response,
+                kwargs=self.request_kwargs,
+            )
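For synchronous streams, chunks are buffered on the handler and only a fully assembled response is written to the cache. A rough sketch of that flow with toy types; `assemble` below stands in for `_assemble_complete_response_from_streaming_chunks`, which is defined elsewhere and not shown in this diff, and the per-chunk buffering is folded into the method purely to keep the toy self-contained:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Chunk:
    text: str
    is_last: bool = False


@dataclass
class StreamCachingHandler:
    sync_streaming_chunks: List[Chunk] = field(default_factory=list)
    cached: Optional[str] = None

    def assemble(self) -> Optional[str]:
        # Stand-in for _assemble_complete_response_from_streaming_chunks:
        # returns a complete response only once the final chunk has arrived.
        if self.sync_streaming_chunks and self.sync_streaming_chunks[-1].is_last:
            return "".join(c.text for c in self.sync_streaming_chunks)
        return None

    def _sync_add_streaming_response_to_cache(self, chunk: Chunk) -> None:
        # In litellm the per-chunk buffering happens in the stream wrapper;
        # it is done here only so the example runs on its own.
        self.sync_streaming_chunks.append(chunk)
        complete = self.assemble()
        if complete is not None:
            # Only a fully assembled response is written to the cache.
            self.cached = complete


handler = StreamCachingHandler()
for chunk in [Chunk("Hel"), Chunk("lo"), Chunk("!", is_last=True)]:
    handler._sync_add_streaming_response_to_cache(chunk)
print(handler.cached)  # "Hello!"
```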


@@ -938,19 +938,6 @@ class Logging:
         else:
             callbacks = litellm.success_callback

-        ## STREAMING CACHING ##
-        if "cache" in callbacks and litellm.cache is not None:
-            # this only logs streaming once, complete_streaming_response exists i.e when stream ends
-            print_verbose("success_callback: reaches cache for logging!")
-            kwargs = self.model_call_details
-            if self.stream and _caching_complete_streaming_response is not None:
-                print_verbose(
-                    "success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
-                )
-                result = _caching_complete_streaming_response
-                # only add to cache once we have a complete streaming response
-                litellm.cache.add_cache(result, **kwargs)
-
         ## REDACT MESSAGES ##
         result = redact_message_input_output_from_logging(
             model_call_details=(


@@ -765,7 +765,9 @@
         print_args_passed_to_litellm(original_function, args, kwargs)
         start_time = datetime.datetime.now()
         result = None
-        logging_obj = kwargs.get("litellm_logging_obj", None)
+        logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
+            "litellm_logging_obj", None
+        )

         # only set litellm_call_id if its not in kwargs
         call_type = original_function.__name__
@@ -787,6 +789,12 @@
             original_function.__name__, rules_obj, start_time, *args, **kwargs
         )
         kwargs["litellm_logging_obj"] = logging_obj
+        _llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
+            original_function=original_function,
+            request_kwargs=kwargs,
+            start_time=start_time,
+        )
+        logging_obj._llm_caching_handler = _llm_caching_handler

         # CHECK FOR 'os.environ/' in kwargs
         for k, v in kwargs.items():
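The wrapper builds one `LLMCachingHandler` per call and hangs it off the logging object, so the streaming iterator, which only holds a reference to `logging_obj`, can reach the same handler when the stream finishes. A toy illustration of that wiring pattern; all names below are illustrative:

```python
from typing import Any, Optional


class ToyCachingHandler:
    def _sync_add_streaming_response_to_cache(self, chunk: Any) -> None:
        print(f"caching handler saw final chunk: {chunk!r}")


class ToyLoggingObject:
    def __init__(self) -> None:
        # The wrapper sets this once per request; downstream consumers check it.
        self._llm_caching_handler: Optional[ToyCachingHandler] = None


class ToyStreamWrapper:
    def __init__(self, logging_obj: ToyLoggingObject) -> None:
        self.logging_obj = logging_obj

    def on_stream_end(self, final_chunk: Any) -> None:
        # Mirrors the None-check in the diff before using the handler.
        if self.logging_obj._llm_caching_handler is not None:
            self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
                final_chunk
            )


logging_obj = ToyLoggingObject()
logging_obj._llm_caching_handler = ToyCachingHandler()  # done by the client wrapper
ToyStreamWrapper(logging_obj).on_stream_end("last chunk")
```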
@@ -1013,12 +1021,11 @@
            )

            # [OPTIONAL] ADD TO CACHE
-            if (
-                litellm.cache is not None
-                and litellm.cache.supported_call_types is not None
-                and call_type in litellm.cache.supported_call_types
-            ) and (kwargs.get("cache", {}).get("no-store", False) is not True):
-                litellm.cache.add_cache(result, *args, **kwargs)
+            _llm_caching_handler._sync_set_cache(
+                result=result,
+                args=args,
+                kwargs=kwargs,
+            )

            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
            verbose_logger.info("Wrapper: Completed Call, calling success_handler")
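`_sync_set_cache` keeps the `supported_call_types` gate that the removed inline check applied at this call site, so restricting which call types get stored still works the same way. A hedged configuration sketch, assuming `Cache(supported_call_types=...)` as exposed by litellm's caching docs; confirm the accepted values for your version:

```python
import litellm
from litellm.caching import Cache  # import path may differ across versions

# Assumption: supported_call_types restricts which call types are read from /
# written to the cache; the values mirror the wrapped function names.
litellm.cache = Cache(supported_call_types=["completion", "acompletion"])

# Eligible for storage: "completion" is listed above.
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)

# Not stored: "embedding" is not in supported_call_types.
litellm.embedding(model="text-embedding-ada-002", input=["hi"])
```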
@@ -7886,7 +7893,10 @@
         """
         self.logging_loop = loop

-    def run_success_logging_in_thread(self, processed_chunk, cache_hit: bool):
+    def run_success_logging_and_cache_storage(self, processed_chunk, cache_hit: bool):
+        """
+        Runs success logging in a thread and adds the response to the cache
+        """
         if litellm.disable_streaming_logging is True:
             """
             [NOT RECOMMENDED]
@@ -7914,6 +7924,12 @@
             ## SYNC LOGGING
             self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
+
+            ## Sync store in cache
+            if self.logging_obj._llm_caching_handler is not None:
+                self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
+                    processed_chunk
+                )

     def finish_reason_handler(self):
         model_response = self.model_response_creator()
         if self.received_finish_reason is not None:
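Because `run_success_logging_and_cache_storage` is launched with `threading.Thread(...).start()` (see the hunks below), the logging and the cache write for a finished stream stay off the stream consumer's hot path. A generic fire-and-forget sketch of that pattern, with toy functions rather than litellm internals:

```python
import threading
import time
from typing import Any


def log_and_store(chunk: Any, cache_hit: bool) -> None:
    # Stand-in for run_success_logging_and_cache_storage: success logging runs
    # for every chunk, and the cache write happens once the stream is complete.
    time.sleep(0.05)  # pretend logging / cache I/O takes a moment
    print(f"logged chunk {chunk!r} (cache_hit={cache_hit})")
    if chunk == "final":
        print("assembled response stored in cache")


def consume_stream() -> None:
    for chunk in ["a", "b", "final"]:
        # The consumer keeps iterating immediately; logging and caching are
        # detached onto a background thread, as in the hunks below.
        threading.Thread(target=log_and_store, args=(chunk, False)).start()
        print(f"delivered chunk {chunk!r} to caller")


consume_stream()
time.sleep(0.5)  # let the background threads finish before the script exits
```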
@@ -7960,7 +7976,7 @@
                     continue
                 ## LOGGING
                 threading.Thread(
-                    target=self.run_success_logging_in_thread,
+                    target=self.run_success_logging_and_cache_storage,
                     args=(response, cache_hit),
                 ).start()  # log response
                 choice = response.choices[0]
@@ -8028,7 +8044,7 @@
                 processed_chunk._hidden_params["usage"] = usage
                 ## LOGGING
                 threading.Thread(
-                    target=self.run_success_logging_in_thread,
+                    target=self.run_success_logging_and_cache_storage,
                     args=(processed_chunk, cache_hit),
                 ).start()  # log response
                 return processed_chunk
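Taken together, these changes store non-streaming results from the client wrapper and finished sync streams from the stream wrapper through the same `LLMCachingHandler` path. A hedged end-to-end usage sketch; it assumes an OpenAI key is configured, and whether the second call is actually served from cache depends on your litellm version and cache backend:

```python
import litellm
from litellm.caching import Cache  # import path may differ across versions

litellm.cache = Cache()  # in-memory cache; Redis and others are also supported

messages = [{"role": "user", "content": "write a haiku about caching"}]

# First call: the client wrapper stores the result via the caching handler.
first = litellm.completion(model="gpt-3.5-turbo", messages=messages)

# Second identical call: eligible to be answered from the cache instead of
# hitting the provider (subject to TTL and cache configuration).
second = litellm.completion(model="gpt-3.5-turbo", messages=messages)
print(second.choices[0].message.content)
```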