(refactor caching) use LLMCachingHandler for caching streaming responses (#6210)
* use folder for caching
* fix importing caching
* fix clickhouse pyright
* fix linting
* fix correctly pass kwargs and args
* fix test case for embedding
* fix linting
* fix embedding caching logic
* fix refactor handle utils.py
* refactor async set stream cache
* fix linting
This commit is contained in:
parent 78f3228e17
commit d1bef4ad81
3 changed files with 75 additions and 37 deletions
litellm/caching/caching_handler.py
@@ -13,12 +13,12 @@ In each method it will call the appropriate method from caching.py
 import asyncio
 import datetime
 import threading
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 from pydantic import BaseModel

 import litellm
-from litellm._logging import print_verbose
+from litellm._logging import print_verbose, verbose_logger
 from litellm.caching.caching import (
     Cache,
     QdrantSemanticCache,
@@ -57,7 +57,16 @@ class CachingHandlerResponse(BaseModel):


 class LLMCachingHandler:
-    def __init__(self):
+    def __init__(
+        self,
+        original_function: Callable,
+        request_kwargs: Dict[str, Any],
+        start_time: datetime.datetime,
+    ):
+        self.async_streaming_chunks: List[ModelResponse] = []
+        self.request_kwargs = request_kwargs
+        self.original_function = original_function
+        self.start_time = start_time
         pass

     async def _async_get_cache(
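For orientation, a minimal sketch of what the new constructor expects: one handler per request, carrying the wrapped function, the request kwargs, and the call's start time. This assumes litellm at this commit is importable; fake_completion and the model/messages values below are illustrative stand-ins, not part of the change.

import datetime

from litellm.caching.caching_handler import LLMCachingHandler


def fake_completion(**kwargs):
    """Hypothetical stand-in for the wrapped litellm call."""
    ...


# One handler is built per request; it holds everything needed to cache the
# streaming response once the final chunk arrives.
handler = LLMCachingHandler(
    original_function=fake_completion,
    request_kwargs={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
    start_time=datetime.datetime.now(),
)
print(handler.async_streaming_chunks)  # starts out as an empty list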
@@ -438,3 +447,45 @@ class LLMCachingHandler:
             asyncio.create_task(
                 litellm.cache.async_add_cache(result, *args, **kwargs)
             )
+
+    async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
+        """
+        Internal method to add the streaming response to the cache
+
+
+        - If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object
+        - Else append the chunk to self.async_streaming_chunks
+
+        """
+        complete_streaming_response: Optional[
+            Union[ModelResponse, TextCompletionResponse]
+        ] = None
+        if (
+            processed_chunk.choices[0].finish_reason is not None
+        ):  # if it's the last chunk
+            self.async_streaming_chunks.append(processed_chunk)
+            try:
+                end_time: datetime.datetime = datetime.datetime.now()
+                complete_streaming_response = litellm.stream_chunk_builder(
+                    self.async_streaming_chunks,
+                    messages=self.request_kwargs.get("messages", None),
+                    start_time=self.start_time,
+                    end_time=end_time,
+                )
+            except Exception as e:
+                verbose_logger.exception(
+                    "Error occurred building stream chunk in success logging: {}".format(
+                        str(e)
+                    )
+                )
+                complete_streaming_response = None
+        else:
+            self.async_streaming_chunks.append(processed_chunk)
+
+        # if a complete_streaming_response is assembled, add it to the cache
+        if complete_streaming_response is not None:
+            await self._async_set_cache(
+                result=complete_streaming_response,
+                original_function=self.original_function,
+                kwargs=self.request_kwargs,
+            )
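The method added above follows a buffer-until-finished pattern: every chunk is appended, and only the chunk carrying a finish_reason triggers assembly and the cache write. A self-contained sketch of that control flow, where Chunk, StreamingCacheBuffer, the join-based assembly, and write_to_cache are hypothetical stand-ins for ModelResponse, LLMCachingHandler, litellm.stream_chunk_builder, and _async_set_cache:

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Chunk:
    content: str
    finish_reason: Optional[str] = None  # set on the final chunk only


class StreamingCacheBuffer:
    """Hypothetical stand-in for the chunk-buffering part of LLMCachingHandler."""

    def __init__(self) -> None:
        self.chunks: List[Chunk] = []

    async def add(self, chunk: Chunk) -> None:
        self.chunks.append(chunk)
        if chunk.finish_reason is None:
            return  # not the last chunk yet; just buffer it
        # assemble the full response only once the stream is complete
        full_text = "".join(c.content for c in self.chunks)
        await self.write_to_cache(full_text)

    async def write_to_cache(self, text: str) -> None:
        print(f"caching complete response: {text!r}")


async def main() -> None:
    buffer = StreamingCacheBuffer()
    for chunk in [Chunk("Hel"), Chunk("lo"), Chunk("", finish_reason="stop")]:
        await buffer.add(chunk)


asyncio.run(main())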
litellm/litellm_core_utils/litellm_logging.py
@@ -24,6 +24,7 @@ from litellm import (
     verbose_logger,
 )
 from litellm.caching.caching import DualCache, InMemoryCache, S3Cache
+from litellm.caching.caching_handler import LLMCachingHandler
 from litellm.cost_calculator import _select_model_name_for_cost_calc
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
@@ -271,6 +272,7 @@ class Logging:

         ## TIME TO FIRST TOKEN LOGGING ##
         self.completion_start_time: Optional[datetime.datetime] = None
+        self._llm_caching_handler: Optional[LLMCachingHandler] = None

     def process_dynamic_callbacks(self):
         """
@@ -1625,35 +1627,6 @@ class Logging:
                 if kwargs.get("no-log", False) is True:
                     print_verbose("no-log request, skipping logging")
                     continue
-                if (
-                    callback == "cache"
-                    and litellm.cache is not None
-                    and self.model_call_details.get("litellm_params", {}).get(
-                        "acompletion", False
-                    )
-                    is True
-                ):
-                    # set_cache once complete streaming response is built
-                    print_verbose("async success_callback: reaches cache for logging!")
-                    kwargs = self.model_call_details
-                    if self.stream:
-                        if "async_complete_streaming_response" not in kwargs:
-                            print_verbose(
-                                f"async success_callback: reaches cache for logging, there is no async_complete_streaming_response. Kwargs={kwargs}\n\n"
-                            )
-                            pass
-                        else:
-                            print_verbose(
-                                "async success_callback: reaches cache for logging, there is a async_complete_streaming_response. Adding to cache"
-                            )
-                            result = kwargs["async_complete_streaming_response"]
-                            # only add to cache once we have a complete streaming response
-                            if litellm.cache is not None and not isinstance(
-                                litellm.cache.cache, S3Cache
-                            ):
-                                await litellm.cache.async_add_cache(result, **kwargs)
-                            else:
-                                litellm.cache.add_cache(result, **kwargs)
                 if callback == "openmeter" and openMeterLogger is not None:
                     if self.stream is True:
                         if (
litellm/utils.py
@@ -58,8 +58,6 @@ import litellm.litellm_core_utils.audio_utils.utils
 import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching.caching import DualCache
 from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
-
-_llm_caching_handler = LLMCachingHandler()
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.exception_mapping_utils import (
@@ -73,6 +71,7 @@ from litellm.litellm_core_utils.get_llm_provider_logic import (
 )
 from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
 from litellm.litellm_core_utils.redact_messages import (
+    LiteLLMLoggingObject,
     redact_message_input_output_from_logging,
 )
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
@@ -1095,7 +1094,14 @@ def client(original_function):
         print_args_passed_to_litellm(original_function, args, kwargs)
         start_time = datetime.datetime.now()
         result = None
-        logging_obj = kwargs.get("litellm_logging_obj", None)
+        logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
+            "litellm_logging_obj", None
+        )
+        _llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
+            original_function=original_function,
+            request_kwargs=kwargs,
+            start_time=start_time,
+        )
         # only set litellm_call_id if its not in kwargs
         call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
@@ -1117,7 +1123,7 @@
                    original_function.__name__, rules_obj, start_time, *args, **kwargs
                )
            kwargs["litellm_logging_obj"] = logging_obj
-
+            logging_obj._llm_caching_handler = _llm_caching_handler
            # [OPTIONAL] CHECK BUDGET
            if litellm.max_budget:
                if litellm._current_cost > litellm.max_budget:
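Taken together, the two hunks above wire a fresh handler into every wrapped call: build it when the call starts, then attach it to the logging object so the stream wrapper can reach it later. A simplified, runnable sketch of that wiring, where CachingHandler, LoggingObject, client, and completion are hypothetical stand-ins rather than litellm's real implementations:

import datetime
import functools
from typing import Any, Callable, Dict, Optional


class CachingHandler:
    """Hypothetical stand-in for LLMCachingHandler."""

    def __init__(
        self,
        original_function: Callable,
        request_kwargs: Dict[str, Any],
        start_time: datetime.datetime,
    ) -> None:
        self.original_function = original_function
        self.request_kwargs = request_kwargs
        self.start_time = start_time


class LoggingObject:
    """Hypothetical stand-in for the per-request Logging object."""

    def __init__(self) -> None:
        self._llm_caching_handler: Optional[CachingHandler] = None


def client(original_function: Callable) -> Callable:
    @functools.wraps(original_function)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start_time = datetime.datetime.now()
        # one handler per call, created before the provider is ever hit
        caching_handler = CachingHandler(original_function, kwargs, start_time)
        logging_obj = kwargs.setdefault("litellm_logging_obj", LoggingObject())
        # expose the handler to whatever consumes the logging object later
        logging_obj._llm_caching_handler = caching_handler
        return original_function(*args, **kwargs)

    return wrapper


@client
def completion(**kwargs: Any) -> str:
    return "ok"


print(completion(model="gpt-3.5-turbo"))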
@@ -6254,7 +6260,7 @@ class CustomStreamWrapper:
         self.model = model
         self.make_call = make_call
         self.custom_llm_provider = custom_llm_provider
-        self.logging_obj = logging_obj
+        self.logging_obj: LiteLLMLoggingObject = logging_obj
         self.completion_stream = completion_stream
         self.sent_first_chunk = False
         self.sent_last_chunk = False
@@ -8067,6 +8073,14 @@ class CustomStreamWrapper:
                            processed_chunk, cache_hit=cache_hit
                        )
                    )
+
+                    if self.logging_obj._llm_caching_handler is not None:
+                        asyncio.create_task(
+                            self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
+                                processed_chunk=processed_chunk,
+                            )
+                        )
+
                    choice = processed_chunk.choices[0]
                    if isinstance(choice, StreamingChoices):
                        self.response_uptil_now += choice.delta.get("content", "") or ""
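The hook above is deliberately fire-and-forget: asyncio.create_task schedules the cache write without blocking the next chunk from being yielded, and the None check keeps paths that never attached a handler working unchanged. A minimal, self-contained sketch of the same pattern, where cache_chunk and stream are hypothetical stand-ins for the handler method and the stream wrapper:

import asyncio
from typing import AsyncIterator, Awaitable, Callable, List, Optional


async def cache_chunk(chunk: str) -> None:
    """Hypothetical stand-in for _add_streaming_response_to_cache."""
    await asyncio.sleep(0.01)  # simulate the cache write
    print(f"cached: {chunk!r}")


async def stream(
    chunks: List[str],
    caching_handler: Optional[Callable[[str], Awaitable[None]]] = cache_chunk,
) -> AsyncIterator[str]:
    background_tasks = []
    for chunk in chunks:
        if caching_handler is not None:
            # schedule the cache write without delaying the consumer
            background_tasks.append(asyncio.create_task(caching_handler(chunk)))
        yield chunk
    await asyncio.gather(*background_tasks)  # let pending writes finish


async def main() -> None:
    async for chunk in stream(["Hel", "lo"]):
        print(f"yielded: {chunk!r}")


asyncio.run(main())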