(refactor caching) use LLMCachingHandler for caching streaming responses (#6210)

* use folder for caching

* fix importing caching

* fix clickhouse pyright

* fix linting

* fix correctly pass kwargs and args

* fix test case for embedding

* fix linting

* fix embedding caching logic

* fix refactor handle utils.py

* refactor async set stream cache

* fix linting
This commit is contained in:
Ishaan Jaff 2024-10-14 17:46:45 +05:30 committed by GitHub
parent 78f3228e17
commit d1bef4ad81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 75 additions and 37 deletions

View file

@ -58,8 +58,6 @@ import litellm.litellm_core_utils.audio_utils.utils
import litellm.litellm_core_utils.json_validation_rule
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
_llm_caching_handler = LLMCachingHandler()
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.exception_mapping_utils import (
@ -73,6 +71,7 @@ from litellm.litellm_core_utils.get_llm_provider_logic import (
)
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
from litellm.litellm_core_utils.redact_messages import (
LiteLLMLoggingObject,
redact_message_input_output_from_logging,
)
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
@ -1095,7 +1094,14 @@ def client(original_function):
print_args_passed_to_litellm(original_function, args, kwargs)
start_time = datetime.datetime.now()
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
"litellm_logging_obj", None
)
_llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
original_function=original_function,
request_kwargs=kwargs,
start_time=start_time,
)
# only set litellm_call_id if its not in kwargs
call_type = original_function.__name__
if "litellm_call_id" not in kwargs:
@ -1117,7 +1123,7 @@ def client(original_function):
original_function.__name__, rules_obj, start_time, *args, **kwargs
)
kwargs["litellm_logging_obj"] = logging_obj
logging_obj._llm_caching_handler = _llm_caching_handler
# [OPTIONAL] CHECK BUDGET
if litellm.max_budget:
if litellm._current_cost > litellm.max_budget:
@ -6254,7 +6260,7 @@ class CustomStreamWrapper:
self.model = model
self.make_call = make_call
self.custom_llm_provider = custom_llm_provider
self.logging_obj = logging_obj
self.logging_obj: LiteLLMLoggingObject = logging_obj
self.completion_stream = completion_stream
self.sent_first_chunk = False
self.sent_last_chunk = False
@ -8067,6 +8073,14 @@ class CustomStreamWrapper:
processed_chunk, cache_hit=cache_hit
)
)
if self.logging_obj._llm_caching_handler is not None:
asyncio.create_task(
self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
processed_chunk=processed_chunk,
)
)
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""