From 2cd5f0fbe92517d310e3f80ef8dfd7ff9b71b408 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 11 Jan 2024 16:51:34 +0530
Subject: [PATCH] fix(utils.py): support caching individual items in embedding
 input list

https://github.com/BerriAI/litellm/issues/1350
---
 litellm/main.py  |  2 +-
 litellm/utils.py | 83 ++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index cb67774bf..3ec82ed0a 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -234,7 +234,7 @@ async def acompletion(
         }
     try:
         # Use a partial function to pass your keyword arguments
-        func = partial(completion, **completion_kwargs, **kwargs)
+        func = partial(completion, **completion_kwargs)
 
         # Add the context to the function
         ctx = contextvars.copy_context()
diff --git a/litellm/utils.py b/litellm/utils.py
index fcf6e9dea..49bb47420 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2174,6 +2174,7 @@ def client(original_function):
         result = None
         logging_obj = kwargs.get("litellm_logging_obj", None)
         # only set litellm_call_id if its not in kwargs
+        call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
         try:
@@ -2204,6 +2205,7 @@ def client(original_function):
                 f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
             )
             # if caching is false, don't run this
+            final_embedding_cached_response = None
             if (
                 (kwargs.get("caching", None) is None and litellm.cache is not None)
                 or kwargs.get("caching", False) == True
@@ -2220,8 +2222,24 @@ def client(original_function):
                     in litellm.cache.supported_call_types
                 ):
                     print_verbose(f"Checking Cache")
-                    cached_result = litellm.cache.get_cache(*args, **kwargs)
-                    if cached_result != None:
+                    if call_type == CallTypes.aembedding.value and isinstance(
+                        kwargs["input"], list
+                    ):
+                        tasks = []
+                        embedding_kwargs = copy.deepcopy(kwargs)
+                        for idx, i in enumerate(kwargs["input"]):
+                            embedding_kwargs["input"] = i
+                            tasks.append(
+                                litellm.cache._async_get_cache(
+                                    *args, **embedding_kwargs
+                                )
+                            )
+                        cached_result = await asyncio.gather(*tasks)
+                    else:
+                        cached_result = litellm.cache.get_cache(*args, **kwargs)
+                    if cached_result is not None and not isinstance(
+                        cached_result, list
+                    ):
                         print_verbose(f"Cache Hit!")
                         call_type = original_function.__name__
                         if call_type == CallTypes.acompletion.value and isinstance(
@@ -2294,6 +2312,30 @@ def client(original_function):
                             args=(cached_result, start_time, end_time, cache_hit),
                         ).start()
                         return cached_result
+                    elif (
+                        call_type == CallTypes.aembedding.value
+                        and cached_result is not None
+                        and isinstance(cached_result, list)
+                    ):
+                        remaining_list = []
+                        non_null_list = []
+                        for idx, cr in enumerate(cached_result):
+                            if cr is None:
+                                remaining_list.append(kwargs["input"][idx])
+                            else:
+                                non_null_list.append((idx, cr))
+                        original_kwargs_input = kwargs["input"]
+                        kwargs["input"] = remaining_list
+
+                        if len(non_null_list) > 0:
+                            final_embedding_cached_response = EmbeddingResponse(
+                                model=kwargs.get("model"), data=[None] * len(original_kwargs_input)
+                            )
+
+                            for val in non_null_list:
+                                idx, cr = val  # (idx, cr) tuple
+                                if cr is not None:
+                                    final_embedding_cached_response.data[idx] = cr
             # MODEL CALL
             result = await original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
@@ -2323,9 +2365,24 @@ def client(original_function):
                     if isinstance(result, litellm.ModelResponse) or isinstance(
                         result, litellm.EmbeddingResponse
                     ):
-                        asyncio.create_task(
-                            litellm.cache._async_add_cache(result.json(), *args, **kwargs)
-                        )
+                        if isinstance(result, EmbeddingResponse) and isinstance(
+                            kwargs["input"], list
+                        ):
+                            embedding_kwargs = copy.deepcopy(kwargs)
+                            for idx, i in enumerate(kwargs["input"]):
+                                embedding_kwargs["input"] = i
+                                embedding_response = result.data[idx]
+                                asyncio.create_task(
+                                    litellm.cache._async_add_cache(
+                                        embedding_response, *args, **embedding_kwargs
+                                    )
+                                )
+                        else:
+                            asyncio.create_task(
+                                litellm.cache._async_add_cache(
+                                    result.json(), *args, **kwargs
+                                )
+                            )
                     else:
                         asyncio.create_task(
                             litellm.cache._async_add_cache(result, *args, **kwargs)
@@ -2349,6 +2406,22 @@ def client(original_function):
                 result._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai
+            elif (
+                isinstance(result, EmbeddingResponse)
+                and final_embedding_cached_response is not None
+            ):
+                idx = 0
+                final_data_list = []
+                for item in final_embedding_cached_response.data:
+                    if item is None:
+                        final_data_list.append(result.data[idx])
+                        idx += 1
+                    else:
+                        final_data_list.append(item)
+
+                final_embedding_cached_response.data = final_data_list
+                return final_embedding_cached_response
+
             return result
         except Exception as e:
             traceback_exception = traceback.format_exc()
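
Notes with illustrative sketches (not part of the patch itself):

The main.py change drops the second `**kwargs` expansion when building
the partial. In Python, unpacking two dicts that share a key into one
call raises TypeError before the target function ever runs, so passing
both completion_kwargs and kwargs is only safe if the two can never
overlap. A tiny reproduction of that failure mode, with hypothetical
names:

    from functools import partial

    def completion(model=None, **params):
        return model

    merged = {"model": "gpt-3.5-turbo", "temperature": 0.2}
    extra = {"temperature": 0.2}  # overlaps with a key in `merged`

    # partial(completion, **merged, **extra) raises:
    #   TypeError: ... got multiple values for keyword argument 'temperature'
    func = partial(completion, **merged)  # the patched, safe form
    print(func())  # -> gpt-3.5-turbo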
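
On the utils.py side, the wrapper now checks the cache once per item of
an embedding input list rather than once for the whole list, sends only
the misses to the model, and writes each fresh embedding back under its
own key. A minimal sketch of how a caller could exercise this, assuming
the patch is applied and using litellm's in-memory cache (model name and
inputs are arbitrary):

    import asyncio

    import litellm
    from litellm.caching import Cache

    litellm.cache = Cache()  # in-memory cache by default

    async def main():
        # First call: no hits, all three inputs go to the model.
        await litellm.aembedding(
            model="text-embedding-ada-002",
            input=["good morning", "good afternoon", "good evening"],
        )
        # Second call: "good morning" and "good evening" come back from
        # the cache item-by-item; only the new string reaches the model.
        await litellm.aembedding(
            model="text-embedding-ada-002",
            input=["good morning", "brand new text", "good evening"],
        )

    asyncio.run(main())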
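
The split-and-merge idea itself is independent of litellm's internals:
look up each item, batch only the misses into a single model call, then
stitch cached and fresh results back together in input order, which is
what the final_embedding_cached_response merging above does. A
self-contained sketch of the same technique (lookup, store, and compute
are hypothetical stand-ins for the cache and the model call):

    from typing import Callable, Optional

    Vector = list[float]

    def embed_with_item_cache(
        inputs: list[str],
        lookup: Callable[[str], Optional[Vector]],     # per-item cache get
        store: Callable[[str, Vector], None],          # per-item cache set
        compute: Callable[[list[str]], list[Vector]],  # batched model call
    ) -> list[Vector]:
        # Per-item lookups, mirroring the _async_get_cache gather above.
        cached = [lookup(text) for text in inputs]

        # Only the misses are sent to the model, in one batch.
        misses = [t for t, hit in zip(inputs, cached) if hit is None]
        fresh = iter(compute(misses) if misses else [])

        # Merge: hits keep their original slot, fresh results fill the
        # gaps, so output order always matches input order.
        merged: list[Vector] = []
        for text, hit in zip(inputs, cached):
            if hit is None:
                vec = next(fresh)
                store(text, vec)  # write back for next time
                merged.append(vec)
            else:
                merged.append(hit)
        return merged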