From d1bef4ad813dab345cc3993dcfc9e17dd3213fa0 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Mon, 14 Oct 2024 17:46:45 +0530
Subject: [PATCH] (refactor caching) use LLMCachingHandler for caching
 streaming responses (#6210)

* use folder for caching

* fix importing caching

* fix clickhouse pyright

* fix linting

* fix correctly pass kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix refactor handle utils.py

* refactor async set stream cache

* fix linting
---
 litellm/caching/caching_handler.py            | 57 ++++++++++++++++++-
 litellm/litellm_core_utils/litellm_logging.py | 31 +---------
 litellm/utils.py                              | 24 ++++++--
 3 files changed, 75 insertions(+), 37 deletions(-)

diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py
index 58de0b42f..11f055ffe 100644
--- a/litellm/caching/caching_handler.py
+++ b/litellm/caching/caching_handler.py
@@ -13,12 +13,12 @@ In each method it will call the appropriate method from caching.py
 import asyncio
 import datetime
 import threading
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 from pydantic import BaseModel
 
 import litellm
-from litellm._logging import print_verbose
+from litellm._logging import print_verbose, verbose_logger
 from litellm.caching.caching import (
     Cache,
     QdrantSemanticCache,
@@ -57,7 +57,16 @@ class CachingHandlerResponse(BaseModel):
 
 
 class LLMCachingHandler:
-    def __init__(self):
+    def __init__(
+        self,
+        original_function: Callable,
+        request_kwargs: Dict[str, Any],
+        start_time: datetime.datetime,
+    ):
+        self.async_streaming_chunks: List[ModelResponse] = []
+        self.request_kwargs = request_kwargs
+        self.original_function = original_function
+        self.start_time = start_time
         pass
 
     async def _async_get_cache(
@@ -438,3 +447,45 @@ class LLMCachingHandler:
             asyncio.create_task(
                 litellm.cache.async_add_cache(result, *args, **kwargs)
             )
+
+    async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
+        """
+        Internal method to add the streaming response to the cache
+
+
+        - If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object
+        - Else append the chunk to self.async_streaming_chunks
+
+        """
+        complete_streaming_response: Optional[
+            Union[ModelResponse, TextCompletionResponse]
+        ] = None
+        if (
+            processed_chunk.choices[0].finish_reason is not None
+        ):  # if it's the last chunk
+            self.async_streaming_chunks.append(processed_chunk)
+            try:
+                end_time: datetime.datetime = datetime.datetime.now()
+                complete_streaming_response = litellm.stream_chunk_builder(
+                    self.async_streaming_chunks,
+                    messages=self.request_kwargs.get("messages", None),
+                    start_time=self.start_time,
+                    end_time=end_time,
+                )
+            except Exception as e:
+                verbose_logger.exception(
+                    "Error occurred building stream chunk in success logging: {}".format(
+                        str(e)
+                    )
+                )
+                complete_streaming_response = None
+        else:
+            self.async_streaming_chunks.append(processed_chunk)
+
+        # if a complete_streaming_response is assembled, add it to the cache
+        if complete_streaming_response is not None:
+            await self._async_set_cache(
+                result=complete_streaming_response,
+                original_function=self.original_function,
+                kwargs=self.request_kwargs,
+            )
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 4bffa94b3..1d4a3e3cc 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -24,6 +24,7 @@ from litellm import (
     verbose_logger,
 )
 from litellm.caching.caching import DualCache, InMemoryCache, S3Cache
+from litellm.caching.caching_handler import LLMCachingHandler
 from litellm.cost_calculator import _select_model_name_for_cost_calc
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
@@ -271,6 +272,7 @@ class Logging:
 
         ## TIME TO FIRST TOKEN LOGGING ##
         self.completion_start_time: Optional[datetime.datetime] = None
+        self._llm_caching_handler: Optional[LLMCachingHandler] = None
 
     def process_dynamic_callbacks(self):
         """
@@ -1625,35 +1627,6 @@ class Logging:
                 if kwargs.get("no-log", False) is True:
                     print_verbose("no-log request, skipping logging")
                     continue
-                if (
-                    callback == "cache"
-                    and litellm.cache is not None
-                    and self.model_call_details.get("litellm_params", {}).get(
-                        "acompletion", False
-                    )
-                    is True
-                ):
-                    # set_cache once complete streaming response is built
-                    print_verbose("async success_callback: reaches cache for logging!")
-                    kwargs = self.model_call_details
-                    if self.stream:
-                        if "async_complete_streaming_response" not in kwargs:
-                            print_verbose(
-                                f"async success_callback: reaches cache for logging, there is no async_complete_streaming_response. Kwargs={kwargs}\n\n"
-                            )
-                            pass
-                        else:
-                            print_verbose(
-                                "async success_callback: reaches cache for logging, there is a async_complete_streaming_response. Adding to cache"
-                            )
-                            result = kwargs["async_complete_streaming_response"]
-                            # only add to cache once we have a complete streaming response
-                            if litellm.cache is not None and not isinstance(
-                                litellm.cache.cache, S3Cache
-                            ):
-                                await litellm.cache.async_add_cache(result, **kwargs)
-                            else:
-                                litellm.cache.add_cache(result, **kwargs)
                 if callback == "openmeter" and openMeterLogger is not None:
                     if self.stream is True:
                         if (
diff --git a/litellm/utils.py b/litellm/utils.py
index a79a16a58..2457bdf4c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -58,8 +58,6 @@ import litellm.litellm_core_utils.audio_utils.utils
 import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching.caching import DualCache
 from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
-
-_llm_caching_handler = LLMCachingHandler()
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.exception_mapping_utils import (
@@ -73,6 +71,7 @@ from litellm.litellm_core_utils.get_llm_provider_logic import (
 )
 from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
 from litellm.litellm_core_utils.redact_messages import (
+    LiteLLMLoggingObject,
     redact_message_input_output_from_logging,
 )
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
@@ -1095,7 +1094,14 @@ def client(original_function):
         print_args_passed_to_litellm(original_function, args, kwargs)
         start_time = datetime.datetime.now()
         result = None
-        logging_obj = kwargs.get("litellm_logging_obj", None)
+        logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
+            "litellm_logging_obj", None
+        )
+        _llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
+            original_function=original_function,
+            request_kwargs=kwargs,
+            start_time=start_time,
+        )
         # only set litellm_call_id if its not in kwargs
         call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
@@ -1117,7 +1123,7 @@ def client(original_function):
                 original_function.__name__, rules_obj, start_time, *args, **kwargs
             )
         kwargs["litellm_logging_obj"] = logging_obj
-
+        logging_obj._llm_caching_handler = _llm_caching_handler
         # [OPTIONAL] CHECK BUDGET
         if litellm.max_budget:
             if litellm._current_cost > litellm.max_budget:
@@ -6254,7 +6260,7 @@ class CustomStreamWrapper:
         self.model = model
         self.make_call = make_call
         self.custom_llm_provider = custom_llm_provider
-        self.logging_obj = logging_obj
+        self.logging_obj: LiteLLMLoggingObject = logging_obj
         self.completion_stream = completion_stream
         self.sent_first_chunk = False
         self.sent_last_chunk = False
@@ -8067,6 +8073,14 @@ class CustomStreamWrapper:
                         processed_chunk, cache_hit=cache_hit
                     )
                 )
+
+                if self.logging_obj._llm_caching_handler is not None:
+                    asyncio.create_task(
+                        self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
+                            processed_chunk=processed_chunk,
+                        )
+                    )
+
                 choice = processed_chunk.choices[0]
                 if isinstance(choice, StreamingChoices):
                     self.response_uptil_now += choice.delta.get("content", "") or ""
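
Usage sketch (illustrative, not part of the patch): a minimal view of how the refactored handler is driven for one streamed request, based only on the hunks above. The function name `drive_stream_cache`, its parameters, and awaiting the cache write directly (the patch instead schedules it with `asyncio.create_task` inside `CustomStreamWrapper.__anext__`) are assumptions for illustration.

import datetime
from typing import Any, Callable, Dict, List

from litellm import ModelResponse
from litellm.caching.caching_handler import LLMCachingHandler


async def drive_stream_cache(
    original_function: Callable,
    request_kwargs: Dict[str, Any],
    chunks: List[ModelResponse],
) -> None:
    # One handler per request, mirroring its construction in utils.client().
    handler = LLMCachingHandler(
        original_function=original_function,
        request_kwargs=request_kwargs,
        start_time=datetime.datetime.now(),
    )
    for chunk in chunks:
        # Chunks accumulate on the handler; the chunk carrying a finish_reason
        # triggers litellm.stream_chunk_builder() and the async cache write.
        await handler._add_streaming_response_to_cache(processed_chunk=chunk)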