fix(utils.py): support caching individual items in embedding input list

https://github.com/BerriAI/litellm/issues/1350

commit 2cd5f0fbe9 (parent df9df7b040)
2 changed files with 78 additions and 6 deletions
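The diff below changes the `client` wrapper in `litellm/utils.py` so that, for async embedding calls with a list `input`, each item is looked up and stored in the cache individually instead of keying the whole list at once; a repeat request that overlaps an earlier one then only recomputes the items that missed. A minimal caller-side sketch of the intended behavior (the model name, the in-memory cache, and an `OPENAI_API_KEY` in the environment are assumptions, not part of the diff):

import asyncio

import litellm
from litellm.caching import Cache

# Sketch, not part of the commit: default in-memory cache, illustrative model.
litellm.cache = Cache()

async def main():
    # First call: no hits, both embeddings are computed and cached per item.
    await litellm.aembedding(model="text-embedding-ada-002", input=["hello", "world"])
    # Second call: "hello" is served from the cache, only "brand new" hits the API.
    await litellm.aembedding(model="text-embedding-ada-002", input=["hello", "brand new"])

asyncio.run(main())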
litellm/utils.py

@@ -234,7 +234,7 @@ async def acompletion(
     }
     try:
         # Use a partial function to pass your keyword arguments
-        func = partial(completion, **completion_kwargs, **kwargs)
+        func = partial(completion, **completion_kwargs)
 
         # Add the context to the function
         ctx = contextvars.copy_context()
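The first hunk drops the second `**kwargs` unpacking when building the partial: if `completion_kwargs` and `kwargs` ever share a key, Python rejects the duplicate keyword at the call site, so unpacking both is fragile. A tiny repro of that failure mode (names are illustrative):

from functools import partial

def completion(model=None, **kwargs):
    return model

completion_kwargs = {"model": "gpt-3.5-turbo"}
kwargs = {"model": "gpt-4"}  # same key appears in both dicts

# Raises TypeError: got multiple values for keyword argument 'model'
try:
    partial(completion, **completion_kwargs, **kwargs)
except TypeError as e:
    print(e)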
@@ -2174,6 +2174,7 @@ def client(original_function):
         result = None
         logging_obj = kwargs.get("litellm_logging_obj", None)
         # only set litellm_call_id if its not in kwargs
+        call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
         try:
@@ -2204,6 +2205,7 @@ def client(original_function):
                 f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
             )
             # if caching is false, don't run this
+            final_embedding_cached_response = None
             if (
                 (kwargs.get("caching", None) is None and litellm.cache is not None)
                 or kwargs.get("caching", False) == True
@@ -2220,8 +2222,24 @@ def client(original_function):
                     in litellm.cache.supported_call_types
                 ):
                     print_verbose(f"Checking Cache")
-                    cached_result = litellm.cache.get_cache(*args, **kwargs)
-                    if cached_result != None:
+                    if call_type == CallTypes.aembedding.value and isinstance(
+                        kwargs["input"], list
+                    ):
+                        tasks = []
+                        embedding_kwargs = copy.deepcopy(kwargs)
+                        for idx, i in enumerate(kwargs["input"]):
+                            embedding_kwargs["input"] = i
+                            tasks.append(
+                                litellm.cache._async_get_cache(
+                                    *args, **embedding_kwargs
+                                )
+                            )
+                        cached_result = await asyncio.gather(*tasks)
+                    else:
+                        cached_result = litellm.cache.get_cache(*args, **kwargs)
+                    if cached_result is not None and not isinstance(
+                        cached_result, list
+                    ):
                         print_verbose(f"Cache Hit!")
                         call_type = original_function.__name__
                         if call_type == CallTypes.acompletion.value and isinstance(
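For `aembedding` with a list `input`, the hunk above fans out one cache lookup per item and gathers them, so `cached_result` becomes a list with `None` at every miss, while other call types keep the old single `get_cache` path. A self-contained sketch of that fan-out (the dict-backed `async_get` stands in for the internal `litellm.cache._async_get_cache`):

import asyncio

# Sketch, not part of the commit: pretend "hello" was cached by an earlier call.
_store = {"hello": [0.1, 0.2]}

async def async_get(item):
    return _store.get(item)  # None on a cache miss

async def lookup_all(inputs):
    tasks = [async_get(i) for i in inputs]  # one lookup per input item
    return await asyncio.gather(*tasks)     # results keep input order

print(asyncio.run(lookup_all(["hello", "brand new"])))  # [[0.1, 0.2], None]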
@@ -2294,6 +2312,30 @@ def client(original_function):
                                     args=(cached_result, start_time, end_time, cache_hit),
                                 ).start()
                                 return cached_result
+                    elif (
+                        call_type == CallTypes.aembedding.value
+                        and cached_result is not None
+                        and isinstance(cached_result, list)
+                    ):
+                        remaining_list = []
+                        non_null_list = []
+                        for idx, cr in enumerate(cached_result):
+                            if cr is None:
+                                remaining_list.append(kwargs["input"][idx])
+                            else:
+                                non_null_list.append((idx, cr))
+                        original_kwargs_input = kwargs["input"]
+                        kwargs["input"] = remaining_list
+
+                        if len(non_null_list) > 0:
+                            final_embedding_cached_response = EmbeddingResponse(
+                                model=kwargs.get("model"), data=[]
+                            )
+
+                            for val in non_null_list:
+                                idx, cr = val  # (idx, cr) tuple
+                                if cr is not None:
+                                    final_embedding_cached_response.data[idx] = val
             # MODEL CALL
             result = await original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
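On a partial hit, the new `elif` splits the original input: misses go into `remaining_list`, which replaces `kwargs["input"]` so only they are re-sent to the provider, while hits are kept as `(idx, value)` pairs parked in a placeholder `EmbeddingResponse`. A sketch of that partition; it pre-sizes the placeholder with `None`, an assumption about how the hunk's empty `data=[]` list is meant to be indexed:

# Sketch, not part of the commit.
def partition(inputs, cached):
    remaining, hits = [], []
    for idx, cr in enumerate(cached):
        if cr is None:
            remaining.append(inputs[idx])  # miss: must be recomputed
        else:
            hits.append((idx, cr))         # hit: reuse the cached value
    placeholder = [None] * len(inputs)     # keeps hits at their original positions
    for idx, cr in hits:
        placeholder[idx] = cr
    return remaining, placeholder

print(partition(["hello", "brand new"], [[0.1, 0.2], None]))
# (['brand new'], [[0.1, 0.2], None])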
@@ -2323,9 +2365,23 @@ def client(original_function):
                     if isinstance(result, litellm.ModelResponse) or isinstance(
                         result, litellm.EmbeddingResponse
                     ):
-                        asyncio.create_task(
-                            litellm.cache._async_add_cache(result.json(), *args, **kwargs)
-                        )
+                        if isinstance(result, EmbeddingResponse) and isinstance(
+                            kwargs["input"], list
+                        ):
+                            embedding_kwargs = copy.deepcopy(kwargs)
+                            for idx, i in enumerate(kwargs["input"]):
+                                embedding_response = result.data[idx]
+                                asyncio.create_task(
+                                    litellm.cache._async_add_cache(
+                                        embedding_response, *args, **embedding_kwargs
+                                    )
+                                )
+                        else:
+                            asyncio.create_task(
+                                litellm.cache._async_add_cache(
+                                    result.json(), *args, **kwargs
+                                )
+                            )
                     else:
                         asyncio.create_task(
                             litellm.cache._async_add_cache(result, *args, **kwargs)
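On the write path, an `EmbeddingResponse` for a list input is now cached one task per item rather than as a single blob under the whole-list key. A sketch of that fan-out write (the dict-backed `async_add` stands in for the internal `litellm.cache._async_add_cache`; the real code fires and forgets, while this script awaits the tasks so they finish before the loop closes):

import asyncio

_store = {}  # sketch, not part of the commit

async def async_add(key, value):
    _store[key] = value  # stand-in for litellm.cache._async_add_cache

async def write_back(inputs, embeddings):
    tasks = [
        asyncio.create_task(async_add(item, embeddings[idx]))  # one write per item
        for idx, item in enumerate(inputs)
    ]
    await asyncio.gather(*tasks)

asyncio.run(write_back(["hello", "brand new"], [[0.1, 0.2], [0.3, 0.4]]))
print(_store)  # {'hello': [0.1, 0.2], 'brand new': [0.3, 0.4]}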
@@ -2349,6 +2405,22 @@ def client(original_function):
                 result._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai
+            elif (
+                isinstance(result, EmbeddingResponse)
+                and final_embedding_cached_response is not None
+            ):
+                idx = 0
+                final_data_list = []
+                for item in final_embedding_cached_response.data:
+                    if item is None:
+                        final_data_list.append(result.data[idx])
+                    else:
+                        final_data_list.append(item)
+                    idx += 1
+
+                final_embedding_cached_response.data = final_data_list
+                return final_embedding_cached_response
+
             return result
         except Exception as e:
             traceback_exception = traceback.format_exc()
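The last hunk stitches the fresh results back into the placeholder before returning: `None` slots take a recomputed embedding from `result.data`, filled slots keep their cached value. A sketch of that merge, assuming the fresh list holds only the recomputed items so the cursor advances only when one is consumed (an assumption about intent, since the hunk's own `idx += 1` runs on every iteration):

# Sketch, not part of the commit.
def merge(placeholder, fresh):
    merged, cursor = [], 0
    for item in placeholder:
        if item is None:
            merged.append(fresh[cursor])  # fill the miss with a fresh embedding
            cursor += 1
        else:
            merged.append(item)           # keep the cached embedding
    return merged

print(merge([[0.1, 0.2], None], [[0.3, 0.4]]))
# [[0.1, 0.2], [0.3, 0.4]]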