diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py
index 479b3bd1f..adf224152 100644
--- a/litellm/caching/caching_handler.py
+++ b/litellm/caching/caching_handler.py
@@ -67,6 +67,7 @@ class LLMCachingHandler:
         start_time: datetime.datetime,
     ):
         self.async_streaming_chunks: List[ModelResponse] = []
+        self.sync_streaming_chunks: List[ModelResponse] = []
         self.request_kwargs = request_kwargs
         self.original_function = original_function
         self.start_time = start_time
@@ -470,12 +471,11 @@ class LLMCachingHandler:
             None
         """
         args = args or ()
+        if litellm.cache is None:
+            return
         # [OPTIONAL] ADD TO CACHE
-        if (
-            (litellm.cache is not None)
-            and litellm.cache.supported_call_types is not None
-            and (str(original_function.__name__) in litellm.cache.supported_call_types)
-            and (kwargs.get("cache", {}).get("no-store", False) is not True)
+        if self._should_store_result_in_cache(
+            original_function=original_function, kwargs=kwargs
         ):
             if (
                 isinstance(result, litellm.ModelResponse)
@@ -509,6 +509,42 @@ class LLMCachingHandler:
                 litellm.cache.async_add_cache(result, *args, **kwargs)
             )
 
+    def _sync_set_cache(
+        self,
+        result: Any,
+        kwargs: Dict[str, Any],
+        args: Optional[Tuple[Any, ...]] = None,
+    ):
+        """
+        Sync internal method to add the result to the cache
+        """
+        if litellm.cache is None:
+            return
+
+        args = args or ()
+        if self._should_store_result_in_cache(
+            original_function=self.original_function, kwargs=kwargs
+        ):
+            litellm.cache.add_cache(result, *args, **kwargs)
+
+        return
+
+    def _should_store_result_in_cache(
+        self, original_function: Callable, kwargs: Dict[str, Any]
+    ) -> bool:
+        """
+        Helper function to determine if the result should be stored in the cache.
+
+        Returns:
+            bool: True if the result should be stored in the cache, False otherwise.
+        """
+        return (
+            (litellm.cache is not None)
+            and litellm.cache.supported_call_types is not None
+            and (str(original_function.__name__) in litellm.cache.supported_call_types)
+            and (kwargs.get("cache", {}).get("no-store", False) is not True)
+        )
+
     async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
         """
         Internal method to add the streaming response to the cache
@@ -536,3 +572,25 @@ class LLMCachingHandler:
                 original_function=self.original_function,
                 kwargs=self.request_kwargs,
             )
+
+    def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
+        """
+        Sync internal method to add the streaming response to the cache
+        """
+        complete_streaming_response: Optional[
+            Union[ModelResponse, TextCompletionResponse]
+        ] = _assemble_complete_response_from_streaming_chunks(
+            result=processed_chunk,
+            start_time=self.start_time,
+            end_time=datetime.datetime.now(),
+            request_kwargs=self.request_kwargs,
+            streaming_chunks=self.sync_streaming_chunks,
+            is_async=False,
+        )
+
+        # if a complete_streaming_response is assembled, add it to the cache
+        if complete_streaming_response is not None:
+            self._sync_set_cache(
+                result=complete_streaming_response,
+                kwargs=self.request_kwargs,
+            )
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index f7afbb7b8..8b6dbc1f9 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -938,19 +938,6 @@ class Logging:
             else:
                 callbacks = litellm.success_callback
 
-            ## STREAMING CACHING ##
-            if "cache" in callbacks and litellm.cache is not None:
-                # this only logs streaming once, complete_streaming_response exists i.e when stream ends
-                print_verbose("success_callback: reaches cache for logging!")
-                kwargs = self.model_call_details
-                if self.stream and _caching_complete_streaming_response is not None:
-                    print_verbose(
-                        "success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
-                    )
-                    result = _caching_complete_streaming_response
-                    # only add to cache once we have a complete streaming response
-                    litellm.cache.add_cache(result, **kwargs)
-
             ## REDACT MESSAGES ##
             result = redact_message_input_output_from_logging(
                 model_call_details=(
diff --git a/litellm/utils.py b/litellm/utils.py
index 085fe7116..fd18ff7f1 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -765,7 +765,9 @@ def client(original_function):
         print_args_passed_to_litellm(original_function, args, kwargs)
         start_time = datetime.datetime.now()
         result = None
-        logging_obj = kwargs.get("litellm_logging_obj", None)
+        logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
+            "litellm_logging_obj", None
+        )
 
         # only set litellm_call_id if its not in kwargs
         call_type = original_function.__name__
@@ -787,6 +789,12 @@ def client(original_function):
                 original_function.__name__, rules_obj, start_time, *args, **kwargs
             )
             kwargs["litellm_logging_obj"] = logging_obj
+            _llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
+                original_function=original_function,
+                request_kwargs=kwargs,
+                start_time=start_time,
+            )
+            logging_obj._llm_caching_handler = _llm_caching_handler
 
             # CHECK FOR 'os.environ/' in kwargs
             for k, v in kwargs.items():
@@ -1013,12 +1021,11 @@ def client(original_function):
                 )
 
             # [OPTIONAL] ADD TO CACHE
-            if (
-                litellm.cache is not None
-                and litellm.cache.supported_call_types is not None
-                and call_type in litellm.cache.supported_call_types
-            ) and (kwargs.get("cache", {}).get("no-store", False) is not True):
-                litellm.cache.add_cache(result, *args, **kwargs)
+            _llm_caching_handler._sync_set_cache(
+                result=result,
+                args=args,
+                kwargs=kwargs,
+            )
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
             verbose_logger.info("Wrapper: Completed Call, calling success_handler")
@@ -7886,7 +7893,10 @@ class CustomStreamWrapper:
         """
         self.logging_loop = loop
 
-    def run_success_logging_in_thread(self, processed_chunk, cache_hit: bool):
+    def run_success_logging_and_cache_storage(self, processed_chunk, cache_hit: bool):
+        """
+        Runs success logging in a thread and adds the response to the cache
+        """
         if litellm.disable_streaming_logging is True:
             """
             [NOT RECOMMENDED]
@@ -7914,6 +7924,12 @@ class CustomStreamWrapper:
             ## SYNC LOGGING
             self.logging_obj.success_handler(processed_chunk, None, None, cache_hit)
 
+            ## Sync store in cache
+            if self.logging_obj._llm_caching_handler is not None:
+                self.logging_obj._llm_caching_handler._sync_add_streaming_response_to_cache(
+                    processed_chunk
+                )
+
     def finish_reason_handler(self):
         model_response = self.model_response_creator()
         if self.received_finish_reason is not None:
@@ -7960,7 +7976,7 @@ class CustomStreamWrapper:
                     continue
                 ## LOGGING
                 threading.Thread(
-                    target=self.run_success_logging_in_thread,
+                    target=self.run_success_logging_and_cache_storage,
                     args=(response, cache_hit),
                 ).start()  # log response
                 choice = response.choices[0]
@@ -8028,7 +8044,7 @@ class CustomStreamWrapper:
                         processed_chunk._hidden_params["usage"] = usage
                 ## LOGGING
                 threading.Thread(
-                    target=self.run_success_logging_in_thread,
+                    target=self.run_success_logging_and_cache_storage,
                     args=(processed_chunk, cache_hit),
                 ).start()  # log response
                 return processed_chunk