mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
(refactor caching) use LLMCachingHandler for caching streaming responses (#6210)
* use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * refactor async set stream cache * fix linting
This commit is contained in:
parent
78f3228e17
commit
d1bef4ad81
3 changed files with 75 additions and 37 deletions
|
@ -58,8 +58,6 @@ import litellm.litellm_core_utils.audio_utils.utils
|
|||
import litellm.litellm_core_utils.json_validation_rule
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
|
||||
|
||||
_llm_caching_handler = LLMCachingHandler()
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.exception_mapping_utils import (
|
||||
|
@ -73,6 +71,7 @@ from litellm.litellm_core_utils.get_llm_provider_logic import (
|
|||
)
|
||||
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
|
||||
from litellm.litellm_core_utils.redact_messages import (
|
||||
LiteLLMLoggingObject,
|
||||
redact_message_input_output_from_logging,
|
||||
)
|
||||
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
|
||||
|
@ -1095,7 +1094,14 @@ def client(original_function):
|
|||
print_args_passed_to_litellm(original_function, args, kwargs)
|
||||
start_time = datetime.datetime.now()
|
||||
result = None
|
||||
logging_obj = kwargs.get("litellm_logging_obj", None)
|
||||
logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
|
||||
"litellm_logging_obj", None
|
||||
)
|
||||
_llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
|
||||
original_function=original_function,
|
||||
request_kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
)
|
||||
# only set litellm_call_id if its not in kwargs
|
||||
call_type = original_function.__name__
|
||||
if "litellm_call_id" not in kwargs:
|
||||
|
@ -1117,7 +1123,7 @@ def client(original_function):
|
|||
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
||||
)
|
||||
kwargs["litellm_logging_obj"] = logging_obj
|
||||
|
||||
logging_obj._llm_caching_handler = _llm_caching_handler
|
||||
# [OPTIONAL] CHECK BUDGET
|
||||
if litellm.max_budget:
|
||||
if litellm._current_cost > litellm.max_budget:
|
||||
|
@ -6254,7 +6260,7 @@ class CustomStreamWrapper:
|
|||
self.model = model
|
||||
self.make_call = make_call
|
||||
self.custom_llm_provider = custom_llm_provider
|
||||
self.logging_obj = logging_obj
|
||||
self.logging_obj: LiteLLMLoggingObject = logging_obj
|
||||
self.completion_stream = completion_stream
|
||||
self.sent_first_chunk = False
|
||||
self.sent_last_chunk = False
|
||||
|
@ -8067,6 +8073,14 @@ class CustomStreamWrapper:
|
|||
processed_chunk, cache_hit=cache_hit
|
||||
)
|
||||
)
|
||||
|
||||
if self.logging_obj._llm_caching_handler is not None:
|
||||
asyncio.create_task(
|
||||
self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
|
||||
processed_chunk=processed_chunk,
|
||||
)
|
||||
)
|
||||
|
||||
choice = processed_chunk.choices[0]
|
||||
if isinstance(choice, StreamingChoices):
|
||||
self.response_uptil_now += choice.delta.get("content", "") or ""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue