diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index adf224152..ae07066df 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -13,7 +13,18 @@ In each method it will call the appropriate method from caching.py import asyncio import datetime import threading -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Generator, + List, + Optional, + Tuple, + Union, +) from pydantic import BaseModel @@ -41,8 +52,10 @@ from litellm.types.utils import ( if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.utils import CustomStreamWrapper else: LiteLLMLoggingObj = Any + CustomStreamWrapper = Any class CachingHandlerResponse(BaseModel): @@ -108,6 +121,7 @@ class LLMCachingHandler: args = args or () final_embedding_cached_response: Optional[EmbeddingResponse] = None + embedding_all_elements_cache_hit: bool = False cached_result: Optional[Any] = None if ( (kwargs.get("caching", None) is None and litellm.cache is not None) @@ -115,16 +129,10 @@ class LLMCachingHandler: ) and ( kwargs.get("cache", {}).get("no-cache", False) is not True ): # allow users to control returning cached responses from the completion function - # checking cache - print_verbose("INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and litellm.cache.supported_call_types is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types + if litellm.cache is not None and self._is_call_type_supported_by_cache( + original_function=original_function ): print_verbose("Checking Cache") - cached_result = await self._retrieve_from_cache( call_type=call_type, kwargs=kwargs, @@ -135,42 +143,20 @@ class LLMCachingHandler: print_verbose("Cache Hit!") cache_hit = True end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( + model, _, _, _ = litellm.get_llm_provider( model=model, custom_llm_provider=kwargs.get("custom_llm_provider", None), api_base=kwargs.get("api_base", None), api_key=kwargs.get("api_key", None), ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( + self._update_litellm_logging_obj_environment( + logging_obj=logging_obj, model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": True, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get("preset_cache_key", None), - "stream_response": kwargs.get("stream_response", {}), - "api_base": kwargs.get("api_base", ""), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), + kwargs=kwargs, + cached_result=cached_result, + is_async=True, ) + call_type = original_function.__name__ cached_result = self._convert_cached_result_to_model_response( @@ -184,15 +170,13 @@ class LLMCachingHandler: ) if kwargs.get("stream", False) is False: # LOG SUCCESS - asyncio.create_task( - logging_obj.async_success_handler( - cached_result, start_time, end_time, cache_hit - ) + self._async_log_cache_hit_on_callbacks( + logging_obj=logging_obj, + cached_result=cached_result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() cache_key = kwargs.get("preset_cache_key", None) if ( isinstance(cached_result, BaseModel) @@ -209,101 +193,261 @@ class LLMCachingHandler: litellm.cache.cache, S3Cache ) # s3 doesn't support bulk writing. Exclude. ): - remaining_list = [] - non_null_list = [] - for idx, cr in enumerate(cached_result): - if cr is None: - remaining_list.append(kwargs["input"][idx]) - else: - non_null_list.append((idx, cr)) - original_kwargs_input = kwargs["input"] - kwargs["input"] = remaining_list - if len(non_null_list) > 0: - print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}") - final_embedding_cached_response = EmbeddingResponse( - model=kwargs.get("model"), - data=[None] * len(original_kwargs_input), - ) - final_embedding_cached_response._hidden_params["cache_hit"] = ( - True - ) - - for val in non_null_list: - idx, cr = val # (idx, cr) tuple - if cr is not None: - final_embedding_cached_response.data[idx] = Embedding( - embedding=cr["embedding"], - index=idx, - object="embedding", - ) - if len(remaining_list) == 0: - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get("custom_llm_provider", None), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": True, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get("stream_response", {}), - "api_base": "", - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(final_embedding_cached_response), - additional_args=None, - stream=kwargs.get("stream", False), - ) - asyncio.create_task( - logging_obj.async_success_handler( - final_embedding_cached_response, - start_time, - end_time, - cache_hit, - ) - ) - threading.Thread( - target=logging_obj.success_handler, - args=( - final_embedding_cached_response, - start_time, - end_time, - cache_hit, - ), - ).start() - return CachingHandlerResponse( - final_embedding_cached_response=final_embedding_cached_response, - embedding_all_elements_cache_hit=True, - ) + ( + final_embedding_cached_response, + embedding_all_elements_cache_hit, + ) = self._process_async_embedding_cached_response( + final_embedding_cached_response=final_embedding_cached_response, + cached_result=cached_result, + kwargs=kwargs, + logging_obj=logging_obj, + start_time=start_time, + model=model, + ) + return CachingHandlerResponse( + final_embedding_cached_response=final_embedding_cached_response, + embedding_all_elements_cache_hit=embedding_all_elements_cache_hit, + ) return CachingHandlerResponse( cached_result=cached_result, final_embedding_cached_response=final_embedding_cached_response, ) + def _sync_get_cache( + self, + model: str, + original_function: Callable, + logging_obj: LiteLLMLoggingObj, + start_time: datetime.datetime, + call_type: str, + kwargs: Dict[str, Any], + args: Optional[Tuple[Any, ...]] = None, + ) -> CachingHandlerResponse: + from litellm.utils import CustomStreamWrapper + + args = args or () + cached_result: Optional[Any] = None + if litellm.cache is not None and self._is_call_type_supported_by_cache( + original_function=original_function + ): + print_verbose("Checking Cache") + preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) + kwargs["preset_cache_key"] = ( + preset_cache_key # for streaming calls, we need to pass the preset_cache_key + ) + cached_result = litellm.cache.get_cache(*args, **kwargs) + + if cached_result is not None: + if "detail" in cached_result: + # implies an error occurred + pass + else: + call_type = original_function.__name__ + + cached_result = self._convert_cached_result_to_model_response( + cached_result=cached_result, + call_type=call_type, + kwargs=kwargs, + logging_obj=logging_obj, + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + args=args, + ) + + # LOG SUCCESS + cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model or "", + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + self._update_litellm_logging_obj_environment( + logging_obj=logging_obj, + model=model, + kwargs=kwargs, + cached_result=cached_result, + is_async=False, + ) + + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() + cache_key = kwargs.get("preset_cache_key", None) + if ( + isinstance(cached_result, BaseModel) + or isinstance(cached_result, CustomStreamWrapper) + ) and hasattr(cached_result, "_hidden_params"): + cached_result._hidden_params["cache_key"] = cache_key # type: ignore + return CachingHandlerResponse(cached_result=cached_result) + return CachingHandlerResponse(cached_result=cached_result) + + def _process_async_embedding_cached_response( + self, + final_embedding_cached_response: Optional[EmbeddingResponse], + cached_result: List[Optional[Dict[str, Any]]], + kwargs: Dict[str, Any], + logging_obj: LiteLLMLoggingObj, + start_time: datetime.datetime, + model: str, + ) -> Tuple[Optional[EmbeddingResponse], bool]: + """ + Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit + + For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others + This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit + + Args: + final_embedding_cached_response: Optional[EmbeddingResponse]: + cached_result: List[Optional[Dict[str, Any]]]: + kwargs: Dict[str, Any]: + logging_obj: LiteLLMLoggingObj: + start_time: datetime.datetime: + model: str: + + Returns: + Tuple[Optional[EmbeddingResponse], bool]: + Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit + + + """ + embedding_all_elements_cache_hit: bool = False + remaining_list = [] + non_null_list = [] + for idx, cr in enumerate(cached_result): + if cr is None: + remaining_list.append(kwargs["input"][idx]) + else: + non_null_list.append((idx, cr)) + original_kwargs_input = kwargs["input"] + kwargs["input"] = remaining_list + if len(non_null_list) > 0: + print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}") + final_embedding_cached_response = EmbeddingResponse( + model=kwargs.get("model"), + data=[None] * len(original_kwargs_input), + ) + final_embedding_cached_response._hidden_params["cache_hit"] = True + + for val in non_null_list: + idx, cr = val # (idx, cr) tuple + if cr is not None: + final_embedding_cached_response.data[idx] = Embedding( + embedding=cr["embedding"], + index=idx, + object="embedding", + ) + if len(remaining_list) == 0: + # LOG SUCCESS + cache_hit = True + embedding_all_elements_cache_hit = True + end_time = datetime.datetime.now() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + + self._update_litellm_logging_obj_environment( + logging_obj=logging_obj, + model=model, + kwargs=kwargs, + cached_result=final_embedding_cached_response, + is_async=True, + is_embedding=True, + ) + self._async_log_cache_hit_on_callbacks( + logging_obj=logging_obj, + cached_result=final_embedding_cached_response, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + ) + return final_embedding_cached_response, embedding_all_elements_cache_hit + return final_embedding_cached_response, embedding_all_elements_cache_hit + + def _combine_cached_embedding_response_with_api_result( + self, + _caching_handler_response: CachingHandlerResponse, + embedding_response: EmbeddingResponse, + start_time: datetime.datetime, + end_time: datetime.datetime, + ) -> EmbeddingResponse: + """ + Combines the cached embedding response with the API EmbeddingResponse + + For caching there can be a cache hit for some of the inputs in the list and a cache miss for others + This function combines the cached embedding response with the API EmbeddingResponse + + Args: + caching_handler_response: CachingHandlerResponse: + embedding_response: EmbeddingResponse: + + Returns: + EmbeddingResponse: + """ + if _caching_handler_response.final_embedding_cached_response is None: + return embedding_response + + idx = 0 + final_data_list = [] + for item in _caching_handler_response.final_embedding_cached_response.data: + if item is None and embedding_response.data is not None: + final_data_list.append(embedding_response.data[idx]) + idx += 1 + else: + final_data_list.append(item) + + _caching_handler_response.final_embedding_cached_response.data = final_data_list + _caching_handler_response.final_embedding_cached_response._hidden_params[ + "cache_hit" + ] = True + _caching_handler_response.final_embedding_cached_response._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 + return _caching_handler_response.final_embedding_cached_response + + def _async_log_cache_hit_on_callbacks( + self, + logging_obj: LiteLLMLoggingObj, + cached_result: Any, + start_time: datetime.datetime, + end_time: datetime.datetime, + cache_hit: bool, + ): + """ + Helper function to log the success of a cached result on callbacks + + Args: + logging_obj (LiteLLMLoggingObj): The logging object. + cached_result: The cached result. + start_time (datetime): The start time of the operation. + end_time (datetime): The end time of the operation. + cache_hit (bool): Whether it was a cache hit. + """ + asyncio.create_task( + logging_obj.async_success_handler( + cached_result, start_time, end_time, cache_hit + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=(cached_result, start_time, end_time, cache_hit), + ).start() + async def _retrieve_from_cache( self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...] ) -> Optional[Any]: @@ -385,57 +529,60 @@ class LLMCachingHandler: from litellm.utils import ( CustomStreamWrapper, convert_to_model_response_object, + convert_to_streaming_response, convert_to_streaming_response_async, ) - if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict): + if ( + call_type == CallTypes.acompletion.value + or call_type == CallTypes.completion.value + ) and isinstance(cached_result, dict): if kwargs.get("stream", False) is True: - cached_result = convert_to_streaming_response_async( - response_object=cached_result, - ) - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", + cached_result = self._convert_cached_stream_response( + cached_result=cached_result, + call_type=call_type, logging_obj=logging_obj, + model=model, ) else: cached_result = convert_to_model_response_object( response_object=cached_result, model_response_object=ModelResponse(), ) - if call_type == CallTypes.atext_completion.value and isinstance( - cached_result, dict - ): + if ( + call_type == CallTypes.atext_completion.value + or call_type == CallTypes.text_completion.value + ) and isinstance(cached_result, dict): if kwargs.get("stream", False) is True: - cached_result = convert_to_streaming_response_async( - response_object=cached_result, - ) - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", + cached_result = self._convert_cached_stream_response( + cached_result=cached_result, + call_type=call_type, logging_obj=logging_obj, + model=model, ) else: cached_result = TextCompletionResponse(**cached_result) - elif call_type == CallTypes.aembedding.value and isinstance( - cached_result, dict - ): + elif ( + call_type == CallTypes.aembedding.value + or call_type == CallTypes.embedding.value + ) and isinstance(cached_result, dict): cached_result = convert_to_model_response_object( response_object=cached_result, model_response_object=EmbeddingResponse(), response_type="embedding", ) - elif call_type == CallTypes.arerank.value and isinstance(cached_result, dict): + elif ( + call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value + ) and isinstance(cached_result, dict): cached_result = convert_to_model_response_object( response_object=cached_result, model_response_object=None, response_type="rerank", ) - elif call_type == CallTypes.atranscription.value and isinstance( - cached_result, dict - ): + elif ( + call_type == CallTypes.atranscription.value + or call_type == CallTypes.transcription.value + ) and isinstance(cached_result, dict): hidden_params = { "model": "whisper-1", "custom_llm_provider": custom_llm_provider, @@ -449,6 +596,38 @@ class LLMCachingHandler: ) return cached_result + def _convert_cached_stream_response( + self, + cached_result: Any, + call_type: str, + logging_obj: LiteLLMLoggingObj, + model: str, + ) -> CustomStreamWrapper: + from litellm.utils import ( + CustomStreamWrapper, + convert_to_streaming_response, + convert_to_streaming_response_async, + ) + + _stream_cached_result: Union[AsyncGenerator, Generator] + if ( + call_type == CallTypes.acompletion.value + or call_type == CallTypes.atext_completion.value + ): + _stream_cached_result = convert_to_streaming_response_async( + response_object=cached_result, + ) + else: + _stream_cached_result = convert_to_streaming_response( + response_object=cached_result, + ) + return CustomStreamWrapper( + completion_stream=_stream_cached_result, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + async def _async_set_cache( self, result: Any, @@ -545,6 +724,28 @@ class LLMCachingHandler: and (kwargs.get("cache", {}).get("no-store", False) is not True) ) + def _is_call_type_supported_by_cache( + self, + original_function: Callable, + ) -> bool: + """ + Helper function to determine if the call type is supported by the cache. + + call types are acompletion, aembedding, atext_completion, atranscription, arerank + + Defined on `litellm.types.utils.CallTypes` + + Returns: + bool: True if the call type is supported by the cache, False otherwise. + """ + if ( + litellm.cache is not None + and litellm.cache.supported_call_types is not None + and str(original_function.__name__) in litellm.cache.supported_call_types + ): + return True + return False + async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse): """ Internal method to add the streaming response to the cache @@ -594,3 +795,53 @@ class LLMCachingHandler: result=complete_streaming_response, kwargs=self.request_kwargs, ) + + def _update_litellm_logging_obj_environment( + self, + logging_obj: LiteLLMLoggingObj, + model: str, + kwargs: Dict[str, Any], + cached_result: Any, + is_async: bool, + is_embedding: bool = False, + ): + """ + Helper function to update the LiteLLMLoggingObj environment variables. + + Args: + logging_obj (LiteLLMLoggingObj): The logging object to update. + model (str): The model being used. + kwargs (Dict[str, Any]): The keyword arguments from the original function call. + cached_result (Any): The cached result to log. + is_async (bool): Whether the call is asynchronous or not. + is_embedding (bool): Whether the call is for embeddings or not. + + Returns: + None + """ + litellm_params = { + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": is_async, + "api_base": kwargs.get("api_base", ""), + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get("proxy_server_request", None), + "preset_cache_key": kwargs.get("preset_cache_key", None), + "stream_response": kwargs.get("stream_response", {}), + } + + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params=litellm_params, + input=( + kwargs.get("messages", "") + if not is_embedding + else kwargs.get("input", "") + ), + api_key=kwargs.get("api_key", None), + original_response=str(cached_result), + additional_args=None, + stream=kwargs.get("stream", False), + ) diff --git a/litellm/utils.py b/litellm/utils.py index fd18ff7f1..798bed1c6 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -773,6 +773,8 @@ def client(original_function): call_type = original_function.__name__ if "litellm_call_id" not in kwargs: kwargs["litellm_call_id"] = str(uuid.uuid4()) + + model: Optional[str] = None try: model = args[0] if len(args) > 0 else kwargs["model"] except Exception: @@ -844,116 +846,20 @@ def client(original_function): ): # allow users to control returning cached responses from the completion function # checking cache print_verbose("INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and litellm.cache.supported_call_types is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - print_verbose("Checking Cache") - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key + caching_handler_response: CachingHandlerResponse = ( + _llm_caching_handler._sync_get_cache( + model=model or "", + original_function=original_function, + logging_obj=logging_obj, + start_time=start_time, + call_type=call_type, + kwargs=kwargs, + args=args, ) - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result is not None: - if "detail" in cached_result: - # implies an error occurred - pass - else: - call_type = original_function.__name__ - print_verbose( - f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" - ) - if call_type == CallTypes.completion.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - stream=kwargs.get("stream", False), - ) + ) + if caching_handler_response.cached_result is not None: + return caching_handler_response.cached_result - if kwargs.get("stream", False) is True: - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - elif call_type == CallTypes.embedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - response_type="embedding", - ) - elif call_type == CallTypes.rerank.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - response_type="rerank", - ) - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model or "", - custom_llm_provider=kwargs.get( - "custom_llm_provider", None - ), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": False, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get( - "stream_response", {} - ), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() - cache_key = kwargs.get("preset_cache_key", None) - if ( - isinstance(cached_result, BaseModel) - or isinstance(cached_result, CustomStreamWrapper) - ) and hasattr(cached_result, "_hidden_params"): - cached_result._hidden_params["cache_key"] = cache_key # type: ignore - return cached_result - else: - print_verbose( - "Cache Miss! on key - {}".format(preset_cache_key) - ) # CHECK MAX TOKENS if ( kwargs.get("max_tokens", None) is not None @@ -1245,30 +1151,13 @@ def client(original_function): isinstance(result, EmbeddingResponse) and _caching_handler_response.final_embedding_cached_response is not None - and _caching_handler_response.final_embedding_cached_response.data - is not None ): - idx = 0 - final_data_list = [] - for ( - item - ) in _caching_handler_response.final_embedding_cached_response.data: - if item is None and result.data is not None: - final_data_list.append(result.data[idx]) - idx += 1 - else: - final_data_list.append(item) - - _caching_handler_response.final_embedding_cached_response.data = ( - final_data_list + return _llm_caching_handler._combine_cached_embedding_response_with_api_result( + _caching_handler_response=_caching_handler_response, + embedding_response=result, + start_time=start_time, + end_time=end_time, ) - _caching_handler_response.final_embedding_cached_response._hidden_params[ - "cache_hit" - ] = True - _caching_handler_response.final_embedding_cached_response._response_ms = ( - end_time - start_time - ).total_seconds() * 1000 - return _caching_handler_response.final_embedding_cached_response return result except Exception as e: diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index b95b66681..b6ea36c38 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -1067,7 +1067,7 @@ async def test_redis_cache_acompletion_stream_bedrock(): response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) - time.sleep(0.5) + await asyncio.sleep(1) print("\n\n Response 1 content: ", response_1_content, "\n\n") response2 = await litellm.acompletion( @@ -1082,8 +1082,8 @@ async def test_redis_cache_acompletion_stream_bedrock(): response_2_content += chunk.choices[0].delta.content or "" print(response_2_content) - print("\nresponse 1", response_1_content) - print("\nresponse 2", response_2_content) + print("\nfinal response 1", response_1_content) + print("\nfinal response 2", response_2_content) assert ( response_1_content == response_2_content ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"