fix(utils.py): support caching individual items in embedding input list

https://github.com/BerriAI/litellm/issues/1350
Krrish Dholakia 2024-01-11 16:51:34 +05:30
parent df9df7b040
commit 2cd5f0fbe9
2 changed files with 78 additions and 6 deletions
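The change below makes the async embedding path cache each element of the input list under its own key, so a later request can reuse some items from the cache and only send the misses to the provider. As a rough, self-contained sketch of that idea (the dict-backed _cache, _item_key, and embed_with_item_cache below are illustrative stand-ins, not litellm's code):

    import hashlib
    import json

    _cache = {}  # illustrative in-memory store; litellm's cache can also be Redis-backed

    def _item_key(model, text):
        # one cache key per (model, input item) so a request can be a partial hit
        return hashlib.sha256(json.dumps([model, text]).encode()).hexdigest()

    def embed_with_item_cache(model, inputs, embed_fn):
        hits = {}
        misses = []
        for idx, text in enumerate(inputs):
            cached = _cache.get(_item_key(model, text))
            if cached is None:
                misses.append(idx)
            else:
                hits[idx] = cached
        if misses:
            fresh = embed_fn([inputs[idx] for idx in misses])  # only the misses reach the provider
            for idx, vector in zip(misses, fresh):
                _cache[_item_key(model, inputs[idx])] = vector
                hits[idx] = vector
        return [hits[idx] for idx in range(len(inputs))]  # merged back in the original order

With per-item keys, embedding ["a", "b"] and later ["b", "c"] only sends "c" to the provider on the second call.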


@@ -2174,6 +2174,7 @@ def client(original_function):
result = None
logging_obj = kwargs.get("litellm_logging_obj", None)
# only set litellm_call_id if its not in kwargs
+ call_type = original_function.__name__
if "litellm_call_id" not in kwargs:
kwargs["litellm_call_id"] = str(uuid.uuid4())
try:
@@ -2204,6 +2205,7 @@ def client(original_function):
f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
)
# if caching is false, don't run this
+ final_embedding_cached_response = None
if (
(kwargs.get("caching", None) is None and litellm.cache is not None)
or kwargs.get("caching", False) == True
@@ -2220,8 +2222,24 @@ def client(original_function):
in litellm.cache.supported_call_types
):
print_verbose(f"Checking Cache")
- cached_result = litellm.cache.get_cache(*args, **kwargs)
- if cached_result != None:
+ if call_type == CallTypes.aembedding.value and isinstance(
+ kwargs["input"], list
+ ):
+ tasks = []
+ embedding_kwargs = copy.deepcopy(kwargs)
+ for idx, i in enumerate(kwargs["input"]):
+ embedding_kwargs["input"] = i
+ tasks.append(
+ litellm.cache._async_get_cache(
+ *args, **embedding_kwargs
+ )
+ )
+ cached_result = await asyncio.gather(*tasks)
+ else:
+ cached_result = litellm.cache.get_cache(*args, **kwargs)
+ if cached_result is not None and not isinstance(
+ cached_result, list
+ ):
print_verbose(f"Cache Hit!")
call_type = original_function.__name__
if call_type == CallTypes.acompletion.value and isinstance(
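The lookup above fans out one async cache read per element of kwargs["input"] and gathers them, so cached_result becomes a list in which None at position i marks a miss for item i. A condensed sketch of that pattern, assuming a hypothetical cache.async_get in place of litellm.cache._async_get_cache:

    import asyncio
    import copy

    async def lookup_each_item(cache, **kwargs):
        # one cache read per input element; asyncio.gather preserves order
        reads = []
        for item in kwargs["input"]:
            item_kwargs = copy.deepcopy(kwargs)
            item_kwargs["input"] = item
            reads.append(cache.async_get(**item_kwargs))  # hypothetical async getter
        return await asyncio.gather(*reads)  # e.g. [vec_a, None, vec_c]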
@@ -2294,6 +2312,30 @@ def client(original_function):
args=(cached_result, start_time, end_time, cache_hit),
).start()
return cached_result
+ elif (
+ call_type == CallTypes.aembedding.value
+ and cached_result is not None
+ and isinstance(cached_result, list)
+ ):
+ remaining_list = []
+ non_null_list = []
+ for idx, cr in enumerate(cached_result):
+ if cr is None:
+ remaining_list.append(kwargs["input"][idx])
+ else:
+ non_null_list.append((idx, cr))
+ original_kwargs_input = kwargs["input"]
+ kwargs["input"] = remaining_list
+ if len(non_null_list) > 0:
+ final_embedding_cached_response = EmbeddingResponse(
+ model=kwargs.get("model"), data=[]
+ )
+ for val in non_null_list:
+ idx, cr = val # (idx, cr) tuple
+ if cr is not None:
+ final_embedding_cached_response.data[idx] = val
# MODEL CALL
result = await original_function(*args, **kwargs)
end_time = datetime.datetime.now()
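This hunk splits the gathered results: inputs whose lookup came back None are kept in remaining_list for the real model call, while hits are parked together with their original index so the response can be reassembled later. The bookkeeping boils down to something like the following sketch (plain Python structures in place of EmbeddingResponse):

    def split_hits_and_misses(inputs, cached):
        # cached[i] is the per-item lookup result for inputs[i]; None means a miss
        remaining = [inp for inp, hit in zip(inputs, cached) if hit is None]
        hits_by_index = {i: hit for i, hit in enumerate(cached) if hit is not None}
        return remaining, hits_by_index

    # example: only "b" still needs a provider call
    remaining, hits = split_hits_and_misses(["a", "b", "c"], [[0.1], None, [0.3]])
    assert remaining == ["b"] and hits == {0: [0.1], 2: [0.3]}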
@@ -2323,9 +2365,23 @@ def client(original_function):
if isinstance(result, litellm.ModelResponse) or isinstance(
result, litellm.EmbeddingResponse
):
- asyncio.create_task(
- litellm.cache._async_add_cache(result.json(), *args, **kwargs)
- )
+ if isinstance(result, EmbeddingResponse) and isinstance(
+ kwargs["input"], list
+ ):
+ embedding_kwargs = copy.deepcopy(kwargs)
+ for idx, i in enumerate(kwargs["input"]):
+ embedding_response = result.data[idx]
+ asyncio.create_task(
+ litellm.cache._async_add_cache(
+ embedding_response, *args, **embedding_kwargs
+ )
+ )
+ else:
+ asyncio.create_task(
+ litellm.cache._async_add_cache(
+ result.json(), *args, **kwargs
+ )
+ )
else:
asyncio.create_task(
litellm.cache._async_add_cache(result, *args, **kwargs)
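On the write path, each embedding in the result is stored under its own single-item input instead of one entry for the whole request. A sketch of the per-item write, with a hypothetical cache.async_add standing in for litellm.cache._async_add_cache:

    import asyncio
    import copy

    async def cache_each_embedding(cache, result_data, **kwargs):
        # runs inside the async wrapper; one scheduled write per (input item, embedding) pair
        for item, embedding in zip(kwargs["input"], result_data):
            item_kwargs = copy.deepcopy(kwargs)
            item_kwargs["input"] = item  # key the entry by the single input item
            asyncio.create_task(cache.async_add(embedding, **item_kwargs))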
@@ -2349,6 +2405,22 @@ def client(original_function):
result._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
+ elif (
+ isinstance(result, EmbeddingResponse)
+ and final_embedding_cached_response is not None
+ ):
+ idx = 0
+ final_data_list = []
+ for item in final_embedding_cached_response.data:
+ if item is None:
+ final_data_list.append(result.data[idx])
+ else:
+ final_data_list.append(item)
+ idx += 1
+ final_embedding_cached_response.data = final_data_list
+ return final_embedding_cached_response
return result
except Exception as e:
traceback_exception = traceback.format_exc()
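The final hunk merges the fresh provider results back into the partially cached response: positions that were cache hits keep their cached embedding, and the remaining positions are filled from result.data in order. The merge amounts to the following sketch (plain lists in place of EmbeddingResponse.data):

    def merge_cached_and_fresh(cached_slots, fresh):
        # cached_slots has one entry per original input; None marks positions that were recomputed
        fresh_iter = iter(fresh)
        return [next(fresh_iter) if slot is None else slot for slot in cached_slots]

    # positions 0 and 2 came from cache, position 1 from the new model call
    assert merge_cached_and_fresh([[0.1], None, [0.3]], [[0.2]]) == [[0.1], [0.2], [0.3]]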