fix: support async redis caching

Krrish Dholakia 2024-01-12 21:46:41 +05:30
parent 817a3d29b7
commit 007870390d
6 changed files with 357 additions and 122 deletions
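
For orientation, not from the commit itself: "async redis caching" here means cache reads and writes that go through an awaitable client instead of blocking calls. Below is a minimal sketch of that pattern using the redis-py asyncio client; the key name, payload, and TTL are illustrative assumptions, not litellm's actual cache schema.

import asyncio

import redis.asyncio as redis  # requires redis-py >= 4.2

async def main():
    r = redis.Redis(host="localhost", port=6379, decode_responses=True)
    # Store and read back a serialized response under an illustrative key.
    await r.set("litellm:example-cache-key", '{"embedding": [0.1, 0.2]}', ex=60)
    print(await r.get("litellm:example-cache-key"))
    await r.aclose()  # aclose() needs redis-py >= 5; older versions use close()

asyncio.run(main())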


@@ -2214,8 +2214,13 @@ def client(original_function):
             )
             # if caching is false, don't run this
             final_embedding_cached_response = None
             if (
-                (kwargs.get("caching", None) is None and litellm.cache is not None)
+                (
+                    kwargs.get("caching", None) is None
+                    and kwargs.get("cache", None) is None
+                    and litellm.cache is not None
+                )
                 or kwargs.get("caching", False) == True
                 or (
                     kwargs.get("cache", None) is not None
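
The hunk above widens the gate on the embedding cache lookup: the old condition consulted only the caching kwarg and the global litellm.cache, so a per-request cache dict alone was not taken into account. The same predicate restated standalone for clarity; should_check_cache and global_cache are hypothetical names introduced here, and the final clause continues past the end of the hunk shown, so it is abbreviated.

from typing import Any, Optional

def should_check_cache(kwargs: dict, global_cache: Optional[Any]) -> bool:
    # Look up the cache when no per-call flag is set but a global cache
    # exists, or when the caller opted in explicitly.
    return bool(
        (
            kwargs.get("caching", None) is None
            and kwargs.get("cache", None) is None
            and global_cache is not None
        )
        or kwargs.get("caching", False) == True
        # NOTE: the real clause has extra conditions beyond the hunk above
        or kwargs.get("cache", None) is not None
    )

assert should_check_cache({}, global_cache=object())
assert not should_check_cache({"caching": False}, global_cache=object())
assert should_check_cache({"caching": False, "cache": {"ttl": 60}}, global_cache=None)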
@@ -2234,12 +2239,13 @@ def client(original_function):
                     kwargs["input"], list
                 ):
                     tasks = []
-                    embedding_kwargs = copy.deepcopy(kwargs)
                     for idx, i in enumerate(kwargs["input"]):
-                        embedding_kwargs["input"] = i
+                        preset_cache_key = litellm.cache.get_cache_key(
+                            *args, **{**kwargs, "input": i}
+                        )
                         tasks.append(
-                            litellm.cache._async_get_cache(
-                                *args, **embedding_kwargs
+                            litellm.cache.async_get_cache(
+                                cache_key=preset_cache_key
                             )
                         )
                     cached_result = await asyncio.gather(*tasks)
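
This hunk drops the per-item copy.deepcopy(kwargs) in favor of a precomputed per-item cache key, then awaits every lookup concurrently with a single asyncio.gather. A self-contained sketch of that fan-out shape; FakeCache and its methods are hypothetical stand-ins for litellm.cache, and only the key-per-item-plus-gather structure mirrors the diff.

import asyncio
import hashlib

class FakeCache:
    def __init__(self):
        self._store = {}

    def get_cache_key(self, **kwargs) -> str:
        # Deterministic key per call signature (stand-in for the real keying).
        return hashlib.sha256(repr(sorted(kwargs.items())).encode()).hexdigest()

    async def async_get_cache(self, cache_key: str):
        return self._store.get(cache_key)

async def lookup_all(cache: FakeCache, kwargs: dict):
    tasks = []
    for i in kwargs["input"]:
        preset_cache_key = cache.get_cache_key(**{**kwargs, "input": i})
        tasks.append(cache.async_get_cache(cache_key=preset_cache_key))
    # One awaited gather instead of N sequential cache round-trips.
    return await asyncio.gather(*tasks)

print(asyncio.run(lookup_all(FakeCache(), {"model": "m", "input": ["a", "b"]})))
# -> [None, None] on a cold cache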
@@ -2445,24 +2451,28 @@ def client(original_function):
                 if isinstance(result, EmbeddingResponse) and isinstance(
                     kwargs["input"], list
                 ):
-                    embedding_kwargs = copy.deepcopy(kwargs)
                     for idx, i in enumerate(kwargs["input"]):
+                        preset_cache_key = litellm.cache.get_cache_key(
+                            *args, **{**kwargs, "input": i}
+                        )
                         embedding_response = result.data[idx]
-                        embedding_kwargs["input"] = i
                         asyncio.create_task(
-                            litellm.cache._async_add_cache(
-                                embedding_response, *args, **embedding_kwargs
+                            litellm.cache.async_add_cache(
+                                embedding_response,
+                                *args,
+                                cache_key=preset_cache_key,
                             )
                         )
                     # pass
                 else:
                     asyncio.create_task(
-                        litellm.cache._async_add_cache(
+                        litellm.cache.async_add_cache(
                             result.json(), *args, **kwargs
                         )
                     )
             else:
                 asyncio.create_task(
-                    litellm.cache._async_add_cache(result, *args, **kwargs)
+                    litellm.cache.async_add_cache(result, *args, **kwargs)
                 )
             # LOG SUCCESS - handle streaming success logging in the _next_ object
             print_verbose(
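
The write side mirrors the read side: each embedding in result.data is stored under the same per-item key, and the writes go through asyncio.create_task so the response returns without waiting on the cache. A sketch of that fire-and-forget pattern, again with a hypothetical FakeCache in place of litellm.cache.

import asyncio

class FakeCache:
    def __init__(self):
        self._store = {}

    async def async_add_cache(self, result, *, cache_key: str):
        self._store[cache_key] = result

async def write_back(cache: FakeCache, embeddings, keys):
    for emb, key in zip(embeddings, keys):
        # Fire-and-forget: schedule the write, do not await it here.
        asyncio.create_task(cache.async_add_cache(emb, cache_key=key))
    # Yield to the loop once so this demo's tasks run before we inspect.
    await asyncio.sleep(0)
    return cache._store

print(asyncio.run(write_back(FakeCache(), ["emb-a", "emb-b"], ["k1", "k2"])))
# -> {'k1': 'emb-a', 'k2': 'emb-b'}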