fix(caching.py): fix async in-memory caching
Parent: 7f83cca62c
Commit: 40c952f7c2

3 changed files with 16 additions and 42 deletions
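Summary of the changes: (1) `InMemoryCache` gains an `async_set_cache_pipeline` method, so async batch writes work against the in-memory backend and not only Redis (presumably the bug behind the commit title); (2) in `Cache`, the per-item preset embedding cache key now travels through `kwargs["cache_key"]`, and `_add_cache_logic` returns the updated `kwargs`; (3) the Azure embedding caching test asserts a deterministic `cache_hit` flag instead of comparing response times, and a leftover debug `raise` is removed from the Redis test; (4) the `client` wrapper marks cached embedding responses with `_hidden_params["cache_hit"] = True`.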
litellm/caching.py:

```diff
@@ -53,6 +53,13 @@ class InMemoryCache(BaseCache):
     async def async_set_cache(self, key, value, **kwargs):
         self.set_cache(key=key, value=value, **kwargs)

+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        for cache_key, cache_value in cache_list:
+            if ttl is not None:
+                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
+            else:
+                self.set_cache(key=cache_key, value=cache_value)
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
```
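A minimal sketch of driving the new method, assuming only what the hunk shows (`InMemoryCache` lives in litellm/caching.py; the keys, values, and ttl here are illustrative, not from the commit):

```python
import asyncio

from litellm.caching import InMemoryCache


async def main():
    cache = InMemoryCache()
    # Write several entries in one call; the ttl (seconds) applies to each pair.
    await cache.async_set_cache_pipeline(
        cache_list=[("key-1", "value-1"), ("key-2", "value-2")],
        ttl=60,
    )
    print(cache.get_cache(key="key-1"))  # -> "value-1"


asyncio.run(main())
```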
litellm/caching.py (continued):

```diff
@@ -730,10 +737,10 @@ class Cache:
                 preset_cache_key = litellm.cache.get_cache_key(
                     *args, **{**kwargs, "input": i}
                 )
+                kwargs["cache_key"] = preset_cache_key
                 embedding_response = result.data[idx]
-                cache_key, cached_data = self._add_cache_logic(
+                cache_key, cached_data, kwargs = self._add_cache_logic(
                     result=embedding_response,
-                    cache_key=preset_cache_key,
                     *args,
                     **kwargs,
                 )
```
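The intent, as far as the hunk shows: instead of passing the preset key as an explicit `cache_key=` argument, it is stowed in `kwargs`, and `_add_cache_logic` now hands the (possibly updated) `kwargs` back as a third return value so the caller and the helper agree on the key for each embedding item. A hypothetical stand-in illustrating that contract — not litellm's actual implementation:

```python
import time


class CacheSketch:
    """Hypothetical stand-in for litellm.Cache -- not the real implementation."""

    def get_cache_key(self, *args, **kwargs):
        # Stand-in for litellm's real key derivation.
        return str(kwargs.get("input"))

    def _add_cache_logic(self, result, *args, **kwargs):
        # Prefer a key the caller preset under kwargs["cache_key"], falling
        # back to deriving one; return kwargs as a third value (the new part)
        # so the caller keeps working with the same dict it passed in.
        cache_key = kwargs.get("cache_key") or self.get_cache_key(*args, **kwargs)
        cached_data = {"timestamp": time.time(), "response": result}
        return cache_key, cached_data, kwargs
```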
litellm's caching tests:

```diff
@@ -279,55 +279,20 @@ async def test_embedding_caching_azure_individual_items():
     litellm.cache = Cache()
     common_msg = f"hey how's it going {uuid.uuid4()}"
     common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [common_msg]
     embedding_2 = [
         common_msg,
         f"I'm fine {uuid.uuid4()}",
-        common_msg,
-        common_msg,
-        common_msg,
-    ] * 20
-    embedding_2 = [
-        common_msg,
-        f"I'm fine {uuid.uuid4()}",
-        common_msg,
-        common_msg,
-        common_msg,
-    ] * 20
-    embedding_3 = [
-        common_msg_2,
-        common_msg_2,
-        common_msg_2,
-        common_msg_2,
-        f"I'm fine {uuid.uuid4()}",
-    ] * 20  # make sure azure doesn't return cached 'i'm fine' responses
+    ]

     embedding_val_1 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_1, caching=True
     )

-    second_response_start_time = time.time()
     embedding_val_2 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_2, caching=True
     )
-    if embedding_val_2 is not None:
-        second_response_end_time = time.time()
-        second_response_time = second_response_end_time - second_response_start_time
+    print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
+    assert embedding_val_2._hidden_params["cache_hit"] == True

-    third_response_start_time = time.time()
-    embedding_val_3 = await aembedding(
-        model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
-    )
-    if embedding_val_3 is not None:
-        third_response_end_time = time.time()
-        third_response_time = third_response_end_time - third_response_start_time
-
-    print(f"second_response_time: {second_response_time}")
-    print(f"third_response_time: {third_response_time}")
-
-    assert (
-        second_response_time < third_response_time - 0.5
-    )  # make sure it's actually faster
-    raise Exception(f"it works {second_response_time} < {third_response_time}")
-
-
 @pytest.mark.asyncio
```
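The rewritten test trades wall-clock heuristics for a deterministic signal: rather than asserting that a partially cached call runs at least 0.5s faster than a no-cache call (and then unconditionally raising to surface the timings), it asserts `_hidden_params["cache_hit"] == True` on the second response — the flag the utils hunk below sets. Note that `embedding_2` still contains a fresh `f"I'm fine {uuid.uuid4()}"` item alongside the cached `common_msg`, so the flag is evidently set even for partially cached embedding batches.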
```diff
@@ -369,7 +334,6 @@ async def test_redis_cache_basic():
     )
     print(f"stored_val: {stored_val}")
     assert stored_val["id"] == response1.id
-    raise Exception("it worked!")


 def test_redis_cache_completion():
```
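The `raise Exception("it worked!")` looks like leftover debugging scaffolding (failing on purpose so the test runner prints the output); with the assertion directly above it, removing the raise lets `test_redis_cache_basic` actually pass on success.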
litellm/utils.py:

```diff
@@ -2357,6 +2357,9 @@ def client(original_function):
                     model=kwargs.get("model"),
                     data=[None] * len(original_kwargs_input),
                 )
+                final_embedding_cached_response._hidden_params[
+                    "cache_hit"
+                ] = True

                 for val in non_null_list:
                     idx, cr = val  # (idx, cr) tuple
```
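With the flag in place, callers can tell cached from fresh embedding responses. A sketch mirroring the updated test — it assumes Azure credentials and the `azure/azure-embedding-model` deployment used by the test are configured, and is not part of the commit:

```python
import asyncio
import uuid

import litellm
from litellm import aembedding
from litellm.caching import Cache


async def main():
    litellm.cache = Cache()  # in-memory cache
    msg = f"hello {uuid.uuid4()}"

    # First call populates the cache; the second should be served from it.
    await aembedding(model="azure/azure-embedding-model", input=[msg], caching=True)
    cached = await aembedding(
        model="azure/azure-embedding-model", input=[msg], caching=True
    )
    print(cached._hidden_params.get("cache_hit"))  # expected: True


asyncio.run(main())
```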