From bfa26dd5b34ce8eb1b9865ba07433df471dadadc Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 11 Jan 2024 18:14:22 +0530
Subject: [PATCH] fix(utils.py): bug fixes

---
 litellm/tests/test_custom_callback_input.py |  7 +-
 litellm/utils.py                            | 90 +++++++++++++++++++--
 2 files changed, 88 insertions(+), 9 deletions(-)

diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index 0fb69b645..d9364d11e 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -545,8 +545,9 @@ async def test_async_chat_bedrock_stream():
 
 
 # asyncio.run(test_async_chat_bedrock_stream())
-# Text Completion
-
+# Text Completion
+
+
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
@@ -585,6 +586,7 @@ async def test_async_text_completion_openai_stream():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
 
+
 # EMBEDDING
 ## Test OpenAI + Async
 @pytest.mark.asyncio
@@ -758,6 +760,7 @@ async def test_async_embedding_azure_caching():
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
     print(customHandler_caching.states)
+    print(customHandler_caching.errors)
     assert len(customHandler_caching.errors) == 0
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
 
diff --git a/litellm/utils.py b/litellm/utils.py
index 49bb47420..e4c67c0e8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2026,12 +2026,17 @@ def client(original_function):
             )  # if caching is false or cache["no-cache"]==True, don't run this
             if (
-                (kwargs.get("caching", None) is None and litellm.cache is not None)
-                or kwargs.get("caching", False) == True
-                or (
-                    kwargs.get("cache", None) is not None
-                    and kwargs.get("cache", {}).get("no-cache", False) != True
+                (
+                    (kwargs.get("caching", None) is None and litellm.cache is not None)
+                    or kwargs.get("caching", False) == True
+                    or (
+                        kwargs.get("cache", None) is not None
+                        and kwargs.get("cache", {}).get("no-cache", False) != True
+                    )
                 )
+                and kwargs.get("aembedding", False) != True
+                and kwargs.get("acompletion", False) != True
+                and kwargs.get("aimg_generation", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
                 print_verbose(f"INSIDE CHECKING CACHE")
@@ -2329,13 +2334,78 @@ def client(original_function):
 
                         if len(non_null_list) > 0:
                             final_embedding_cached_response = EmbeddingResponse(
-                                model=kwargs.get("model"), data=[]
+                                model=kwargs.get("model"),
+                                data=[None] * len(original_kwargs_input),
                             )
 
                             for val in non_null_list:
                                 idx, cr = val  # (idx, cr) tuple
                                 if cr is not None:
                                     final_embedding_cached_response.data[idx] = val
+
+                        if len(remaining_list) == 0:
+                            # LOG SUCCESS
+                            cache_hit = True
+                            end_time = datetime.datetime.now()
+                            (
+                                model,
+                                custom_llm_provider,
+                                dynamic_api_key,
+                                api_base,
+                            ) = litellm.get_llm_provider(
+                                model=model,
+                                custom_llm_provider=kwargs.get(
+                                    "custom_llm_provider", None
+                                ),
+                                api_base=kwargs.get("api_base", None),
+                                api_key=kwargs.get("api_key", None),
+                            )
+                            print_verbose(
+                                f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
+                            )
+                            logging_obj.update_environment_variables(
+                                model=model,
+                                user=kwargs.get("user", None),
+                                optional_params={},
+                                litellm_params={
+                                    "logger_fn": kwargs.get("logger_fn", None),
+                                    "acompletion": True,
+                                    "metadata": kwargs.get("metadata", {}),
+                                    "model_info": kwargs.get("model_info", {}),
+                                    "proxy_server_request": kwargs.get(
+                                        "proxy_server_request", None
+                                    ),
+                                    "preset_cache_key": kwargs.get(
+                                        "preset_cache_key", None
+                                    ),
+                                    "stream_response": kwargs.get(
+                                        "stream_response", {}
+                                    ),
+                                },
+                                input=kwargs.get("messages", ""),
+                                api_key=kwargs.get("api_key", None),
+                                original_response=str(final_embedding_cached_response),
+                                additional_args=None,
+                                stream=kwargs.get("stream", False),
+                            )
+                            asyncio.create_task(
+                                logging_obj.async_success_handler(
+                                    final_embedding_cached_response,
+                                    start_time,
+                                    end_time,
+                                    cache_hit,
+                                )
+                            )
+                            threading.Thread(
+                                target=logging_obj.success_handler,
+                                args=(
+                                    final_embedding_cached_response,
+                                    start_time,
+                                    end_time,
+                                    cache_hit,
+                                ),
+                            ).start()
+                            return final_embedding_cached_response
                 # MODEL CALL
                 result = await original_function(*args, **kwargs)
                 end_time = datetime.datetime.now()
@@ -2371,6 +2441,7 @@ def client(original_function):
                         embedding_kwargs = copy.deepcopy(kwargs)
                         for idx, i in enumerate(kwargs["input"]):
                             embedding_response = result.data[idx]
+                            embedding_kwargs["input"] = i
                             asyncio.create_task(
                                 litellm.cache._async_add_cache(
                                     embedding_response, *args, **embedding_kwargs
@@ -5971,7 +6042,12 @@ def exception_type(
                         message=f"BedrockException - {original_exception.message}",
                         llm_provider="bedrock",
                         model=model,
-                        response=original_exception.response,
+                        response=httpx.Response(
+                            status_code=500,
+                            request=httpx.Request(
+                                method="POST", url="https://api.openai.com/v1/"
+                            ),
+                        ),
                     )
                 elif original_exception.status_code == 401:
                     exception_mapping_worked = True