mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(utils.py): support caching individual items in embedding input list
https://github.com/BerriAI/litellm/issues/1350
parent df9df7b040
commit 2cd5f0fbe9
2 changed files with 78 additions and 6 deletions
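
The diff below (from utils.py, per the commit title) makes the async `client` wrapper treat each element of an embedding `input` list as its own cache entry: every item is looked up individually, only the misses are sent to the provider, and cached and fresh embeddings are stitched back together in the original order. As a minimal, self-contained sketch of that flow — the dict cache and `fake_embed` call are illustrative stand-ins, not litellm's API:

```python
import asyncio

# Illustrative in-memory cache; the real cache backend is pluggable.
_cache: dict[str, list[float]] = {}


async def fake_embed(texts: list[str]) -> list[list[float]]:
    # Stand-in for the provider call: one vector per input string.
    return [[float(len(t))] for t in texts]


async def embed_with_item_cache(texts: list[str]) -> list[list[float]]:
    # 1) Look up every input item individually.
    cached = [_cache.get(t) for t in texts]
    misses = [t for t, hit in zip(texts, cached) if hit is None]

    # 2) Call the model only for the misses.
    fresh = await fake_embed(misses) if misses else []

    # 3) Merge fresh and cached vectors back into the original order.
    fresh_iter = iter(fresh)
    merged: list[list[float]] = []
    for t, hit in zip(texts, cached):
        if hit is None:
            vec = next(fresh_iter)
            _cache[t] = vec  # populate the cache for next time
            merged.append(vec)
        else:
            merged.append(hit)
    return merged


if __name__ == "__main__":
    print(asyncio.run(embed_with_item_cache(["alpha", "beta", "alpha"])))
```
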
@@ -2174,6 +2174,7 @@ def client(original_function):
         result = None
         logging_obj = kwargs.get("litellm_logging_obj", None)
         # only set litellm_call_id if its not in kwargs
+        call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
         try:
@@ -2204,6 +2205,7 @@ def client(original_function):
                 f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
             )
             # if caching is false, don't run this
+            final_embedding_cached_response = None
             if (
                 (kwargs.get("caching", None) is None and litellm.cache is not None)
                 or kwargs.get("caching", False) == True
@@ -2220,8 +2222,24 @@ def client(original_function):
                     in litellm.cache.supported_call_types
                 ):
                     print_verbose(f"Checking Cache")
-                    cached_result = litellm.cache.get_cache(*args, **kwargs)
-                    if cached_result != None:
+                    if call_type == CallTypes.aembedding.value and isinstance(
+                        kwargs["input"], list
+                    ):
+                        tasks = []
+                        embedding_kwargs = copy.deepcopy(kwargs)
+                        for idx, i in enumerate(kwargs["input"]):
+                            embedding_kwargs["input"] = i
+                            tasks.append(
+                                litellm.cache._async_get_cache(
+                                    *args, **embedding_kwargs
+                                )
+                            )
+                        cached_result = await asyncio.gather(*tasks)
+                    else:
+                        cached_result = litellm.cache.get_cache(*args, **kwargs)
+                    if cached_result is not None and not isinstance(
+                        cached_result, list
+                    ):
                         print_verbose(f"Cache Hit!")
                         call_type = original_function.__name__
                         if call_type == CallTypes.acompletion.value and isinstance(
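
The hunk above fans out one cache lookup per input item and gathers them, leaving `None` in the positions that missed. A rough equivalent of that fan-out with a hypothetical async cache (names here are illustrative, not litellm's):

```python
import asyncio
import copy


class AsyncDictCache:
    """Hypothetical async cache keyed on the embedding input string."""

    def __init__(self) -> None:
        self._store: dict[str, list[float]] = {}

    async def async_get(self, **kwargs):
        return self._store.get(kwargs["input"])


async def lookup_each_item(cache: AsyncDictCache, kwargs: dict) -> list:
    # Copy kwargs once, then swap a single input item in per lookup,
    # mirroring the per-item cache calls in the diff.
    embedding_kwargs = copy.deepcopy(kwargs)
    tasks = []
    for item in kwargs["input"]:
        embedding_kwargs["input"] = item
        # The coroutine binds the expanded kwargs at call time, so mutating
        # embedding_kwargs on the next iteration does not affect it.
        tasks.append(cache.async_get(**embedding_kwargs))
    # One entry per input item; None marks a cache miss.
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    cache = AsyncDictCache()
    cache._store["hello"] = [0.1, 0.2]
    hits = asyncio.run(lookup_each_item(cache, {"model": "m", "input": ["hello", "world"]}))
    print(hits)  # [[0.1, 0.2], None]
```
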
@@ -2294,6 +2312,30 @@ def client(original_function):
                             args=(cached_result, start_time, end_time, cache_hit),
                         ).start()
                         return cached_result
+                    elif (
+                        call_type == CallTypes.aembedding.value
+                        and cached_result is not None
+                        and isinstance(cached_result, list)
+                    ):
+                        remaining_list = []
+                        non_null_list = []
+                        for idx, cr in enumerate(cached_result):
+                            if cr is None:
+                                remaining_list.append(kwargs["input"][idx])
+                            else:
+                                non_null_list.append((idx, cr))
+                        original_kwargs_input = kwargs["input"]
+                        kwargs["input"] = remaining_list
+
+                        if len(non_null_list) > 0:
+                            final_embedding_cached_response = EmbeddingResponse(
+                                model=kwargs.get("model"), data=[]
+                            )
+
+                            for val in non_null_list:
+                                idx, cr = val # (idx, cr) tuple
+                                if cr is not None:
+                                    final_embedding_cached_response.data[idx] = val
             # MODEL CALL
             result = await original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
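
The hunk above splits the gathered lookups into uncached inputs (sent on to the model) and `(index, value)` pairs of cache hits (held back for the final merge). A small sketch of that partitioning step, independent of litellm:

```python
def partition_cached(inputs: list[str], lookups: list) -> tuple[list[str], list[tuple[int, object]]]:
    """Split per-item cache lookups into misses and (index, hit) pairs.

    lookups[i] is None when inputs[i] missed the cache.
    """
    remaining: list[str] = []            # inputs that still need an API call
    hits: list[tuple[int, object]] = []  # (original index, cached embedding)
    for idx, looked_up in enumerate(lookups):
        if looked_up is None:
            remaining.append(inputs[idx])
        else:
            hits.append((idx, looked_up))
    return remaining, hits


if __name__ == "__main__":
    inputs = ["a", "b", "c"]
    lookups = [[0.1], None, [0.3]]
    print(partition_cached(inputs, lookups))  # (['b'], [(0, [0.1]), (2, [0.3])])
```
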
@@ -2323,9 +2365,23 @@ def client(original_function):
                 if isinstance(result, litellm.ModelResponse) or isinstance(
                     result, litellm.EmbeddingResponse
                 ):
-                    asyncio.create_task(
-                        litellm.cache._async_add_cache(result.json(), *args, **kwargs)
-                    )
+                    if isinstance(result, EmbeddingResponse) and isinstance(
+                        kwargs["input"], list
+                    ):
+                        embedding_kwargs = copy.deepcopy(kwargs)
+                        for idx, i in enumerate(kwargs["input"]):
+                            embedding_response = result.data[idx]
+                            asyncio.create_task(
+                                litellm.cache._async_add_cache(
+                                    embedding_response, *args, **embedding_kwargs
+                                )
+                            )
+                    else:
+                        asyncio.create_task(
+                            litellm.cache._async_add_cache(
+                                result.json(), *args, **kwargs
+                            )
+                        )
                 else:
                     asyncio.create_task(
                         litellm.cache._async_add_cache(result, *args, **kwargs)
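
After the provider call, the hunk above schedules one `_async_add_cache` task per returned embedding instead of caching the whole response as a single blob. A rough sketch of that write-back idea, reusing the hypothetical async cache from the earlier example; keying each entry on its single input item is an assumption made for illustration, not something this hunk shows:

```python
import asyncio


class AsyncDictCache:
    """Hypothetical async cache keyed on the embedding input string."""

    def __init__(self) -> None:
        self._store: dict[str, list[float]] = {}

    async def async_set(self, value, **kwargs) -> None:
        self._store[kwargs["input"]] = value


async def write_back_items(cache: AsyncDictCache, inputs: list[str], vectors: list[list[float]]) -> None:
    # One cache entry per (input item, embedding) pair, written concurrently.
    await asyncio.gather(
        *(cache.async_set(vec, input=item) for item, vec in zip(inputs, vectors))
    )


if __name__ == "__main__":
    cache = AsyncDictCache()
    asyncio.run(write_back_items(cache, ["a", "bb"], [[1.0], [2.0]]))
    print(cache._store)  # {'a': [1.0], 'bb': [2.0]}
```
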
@@ -2349,6 +2405,22 @@ def client(original_function):
                 result._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000 # return response latency in ms like openai
+            elif (
+                isinstance(result, EmbeddingResponse)
+                and final_embedding_cached_response is not None
+            ):
+                idx = 0
+                final_data_list = []
+                for item in final_embedding_cached_response.data:
+                    if item is None:
+                        final_data_list.append(result.data[idx])
+                    else:
+                        final_data_list.append(item)
+                    idx += 1
+
+                final_embedding_cached_response.data = final_data_list
+                return final_embedding_cached_response
+
             return result
         except Exception as e:
             traceback_exception = traceback.format_exc()
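
The final hunk rebuilds the response in the original input order: positions that were cache hits keep their cached embedding, and the remaining positions are filled from the fresh provider response. A standalone sketch of that merge idea (not the hunk's exact index bookkeeping), assuming the cached side is a list with `None` placeholders for the misses:

```python
def merge_cached_and_fresh(cached: list, fresh: list) -> list:
    """Fill the None slots in `cached` with items from `fresh`, in order.

    `cached` has one slot per original input; `fresh` holds only the
    embeddings that were actually computed, in input order.
    """
    fresh_iter = iter(fresh)
    merged = []
    for item in cached:
        merged.append(next(fresh_iter) if item is None else item)
    return merged


if __name__ == "__main__":
    cached = [[0.1], None, [0.3], None]
    fresh = [[0.2], [0.4]]
    print(merge_cached_and_fresh(cached, fresh))  # [[0.1], [0.2], [0.3], [0.4]]
```
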