fix(caching.py): fix async in-memory caching

This commit is contained in:
Krrish Dholakia 2024-01-13 15:33:57 +05:30
parent cdadac1649
commit 79cc739b53
3 changed files with 16 additions and 42 deletions

View file

@ -53,6 +53,13 @@ class InMemoryCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs): async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs) self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None):
    """Store a batch of key/value pairs in the in-memory cache.

    Args:
        cache_list: iterable of (key, value) pairs to store.
        ttl: optional time-to-live forwarded to ``set_cache`` for every
            entry; omitted from the call entirely when ``None`` so the
            backend's own default applies.
    """
    # Build the optional kwargs once, outside the loop: passing
    # **{"ttl": ttl} only when a ttl was supplied is equivalent to the
    # conditional call, without duplicating the set_cache invocation.
    extra_kwargs = {} if ttl is None else {"ttl": ttl}
    for entry_key, entry_value in cache_list:
        self.set_cache(key=entry_key, value=entry_value, **extra_kwargs)
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
if key in self.cache_dict: if key in self.cache_dict:
if key in self.ttl_dict: if key in self.ttl_dict:
@ -730,10 +737,10 @@ class Cache:
preset_cache_key = litellm.cache.get_cache_key( preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i} *args, **{**kwargs, "input": i}
) )
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx] embedding_response = result.data[idx]
cache_key, cached_data = self._add_cache_logic( cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_response, result=embedding_response,
cache_key=preset_cache_key,
*args, *args,
**kwargs, **kwargs,
) )

View file

@ -279,55 +279,20 @@ async def test_embedding_caching_azure_individual_items():
litellm.cache = Cache() litellm.cache = Cache()
common_msg = f"hey how's it going {uuid.uuid4()}" common_msg = f"hey how's it going {uuid.uuid4()}"
common_msg_2 = f"hey how's it going {uuid.uuid4()}" common_msg_2 = f"hey how's it going {uuid.uuid4()}"
embedding_1 = [common_msg]
embedding_2 = [ embedding_2 = [
common_msg, common_msg,
f"I'm fine {uuid.uuid4()}", f"I'm fine {uuid.uuid4()}",
common_msg, ]
common_msg,
common_msg,
] * 20
embedding_2 = [
common_msg,
f"I'm fine {uuid.uuid4()}",
common_msg,
common_msg,
common_msg,
] * 20
embedding_3 = [
common_msg_2,
common_msg_2,
common_msg_2,
common_msg_2,
f"I'm fine {uuid.uuid4()}",
] * 20 # make sure azure doesn't return cached 'i'm fine' responses
embedding_val_1 = await aembedding( embedding_val_1 = await aembedding(
model="azure/azure-embedding-model", input=embedding_1, caching=True model="azure/azure-embedding-model", input=embedding_1, caching=True
) )
second_response_start_time = time.time()
embedding_val_2 = await aembedding( embedding_val_2 = await aembedding(
model="azure/azure-embedding-model", input=embedding_2, caching=True model="azure/azure-embedding-model", input=embedding_2, caching=True
) )
if embedding_val_2 is not None: print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
second_response_end_time = time.time() assert embedding_val_2._hidden_params["cache_hit"] == True
second_response_time = second_response_end_time - second_response_start_time
third_response_start_time = time.time()
embedding_val_3 = await aembedding(
model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
)
if embedding_val_3 is not None:
third_response_end_time = time.time()
third_response_time = third_response_end_time - third_response_start_time
print(f"second_response_time: {second_response_time}")
print(f"third_response_time: {third_response_time}")
assert (
second_response_time < third_response_time - 0.5
) # make sure it's actually faster
raise Exception(f"it works {second_response_time} < {third_response_time}")
@pytest.mark.asyncio @pytest.mark.asyncio
@ -369,7 +334,6 @@ async def test_redis_cache_basic():
) )
print(f"stored_val: {stored_val}") print(f"stored_val: {stored_val}")
assert stored_val["id"] == response1.id assert stored_val["id"] == response1.id
raise Exception("it worked!")
def test_redis_cache_completion(): def test_redis_cache_completion():

View file

@ -2357,6 +2357,9 @@ def client(original_function):
model=kwargs.get("model"), model=kwargs.get("model"),
data=[None] * len(original_kwargs_input), data=[None] * len(original_kwargs_input),
) )
final_embedding_cached_response._hidden_params[
"cache_hit"
] = True
for val in non_null_list: for val in non_null_list:
idx, cr = val # (idx, cr) tuple idx, cr = val # (idx, cr) tuple