From 79cc739b533c05726187fc8e6173fadd300de27d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 13 Jan 2024 15:33:57 +0530
Subject: [PATCH] fix(caching.py): fix async in-memory caching

---
 litellm/caching.py            | 11 +++++++--
 litellm/tests/test_caching.py | 44 ++++-------------------------------
 litellm/utils.py              |  3 +++
 3 files changed, 16 insertions(+), 42 deletions(-)

diff --git a/litellm/caching.py b/litellm/caching.py
index 59fc0ab672..594310b319 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -53,6 +53,13 @@ class InMemoryCache(BaseCache):
     async def async_set_cache(self, key, value, **kwargs):
         self.set_cache(key=key, value=value, **kwargs)
 
+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        for cache_key, cache_value in cache_list:
+            if ttl is not None:
+                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
+            else:
+                self.set_cache(key=cache_key, value=cache_value)
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -730,10 +737,10 @@ class Cache:
                     preset_cache_key = litellm.cache.get_cache_key(
                         *args, **{**kwargs, "input": i}
                     )
+                    kwargs["cache_key"] = preset_cache_key
                     embedding_response = result.data[idx]
-                    cache_key, cached_data = self._add_cache_logic(
+                    cache_key, cached_data, kwargs = self._add_cache_logic(
                         result=embedding_response,
-                        cache_key=preset_cache_key,
                         *args,
                         **kwargs,
                     )
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 695ad931a2..89410598e3 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -279,55 +279,20 @@ async def test_embedding_caching_azure_individual_items():
     litellm.cache = Cache()
     common_msg = f"hey how's it going {uuid.uuid4()}"
     common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [common_msg]
     embedding_2 = [
         common_msg,
         f"I'm fine {uuid.uuid4()}",
-        common_msg,
-        common_msg,
-        common_msg,
-    ] * 20
-    embedding_2 = [
-        common_msg,
-        f"I'm fine {uuid.uuid4()}",
-        common_msg,
-        common_msg,
-        common_msg,
-    ] * 20
-    embedding_3 = [
-        common_msg_2,
-        common_msg_2,
-        common_msg_2,
-        common_msg_2,
-        f"I'm fine {uuid.uuid4()}",
-    ] * 20  # make sure azure doesn't return cached 'i'm fine' responses
+    ]
 
     embedding_val_1 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_1, caching=True
     )
-
-    second_response_start_time = time.time()
     embedding_val_2 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_2, caching=True
     )
-    if embedding_val_2 is not None:
-        second_response_end_time = time.time()
-        second_response_time = second_response_end_time - second_response_start_time
-
-    third_response_start_time = time.time()
-    embedding_val_3 = await aembedding(
-        model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
-    )
-    if embedding_val_3 is not None:
-        third_response_end_time = time.time()
-        third_response_time = third_response_end_time - third_response_start_time
-
-    print(f"second_response_time: {second_response_time}")
-    print(f"third_response_time: {third_response_time}")
-
-    assert (
-        second_response_time < third_response_time - 0.5
-    )  # make sure it's actually faster
-    raise Exception(f"it works {second_response_time} < {third_response_time}")
+    print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
+    assert embedding_val_2._hidden_params["cache_hit"] == True
 
 
 @pytest.mark.asyncio
@@ -369,7 +334,6 @@ async def test_redis_cache_basic():
     )
     print(f"stored_val: {stored_val}")
     assert stored_val["id"] == response1.id
-    raise Exception("it worked!")
 
 
 def test_redis_cache_completion():
diff --git a/litellm/utils.py b/litellm/utils.py
index 3fee13937a..344917118d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2357,6 +2357,9 @@ def client(original_function):
                         model=kwargs.get("model"),
                         data=[None] * len(original_kwargs_input),
                     )
+                    final_embedding_cached_response._hidden_params[
+                        "cache_hit"
+                    ] = True
 
                     for val in non_null_list:
                         idx, cr = val  # (idx, cr) tuple