diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 8e84b5453..e97be428a 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1105,19 +1105,19 @@ class Logging: and self.model_call_details.get("litellm_params", {}).get( "acompletion", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "aembedding", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "aimage_generation", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "atranscription", False ) - == False + is not True ): global openMeterLogger if openMeterLogger is None: @@ -1150,19 +1150,19 @@ class Logging: and self.model_call_details.get("litellm_params", {}).get( "acompletion", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "aembedding", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "aimage_generation", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "atranscription", False ) - == False + is not True ): # custom logger class if self.stream and complete_streaming_response is None: callback.log_stream_event( @@ -1190,19 +1190,19 @@ class Logging: and self.model_call_details.get("litellm_params", {}).get( "acompletion", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "aembedding", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "aimage_generation", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "atranscription", False ) - == False + is not True ): # custom logger functions print_verbose( f"success callbacks: Running Custom Callback Function" @@ -1634,11 +1634,11 @@ class Logging: and self.model_call_details.get("litellm_params", {}).get( "acompletion", False ) - == False + is not True and self.model_call_details.get("litellm_params", {}).get( "aembedding", False ) - == False + is not True ): # custom logger class callback.log_failure_event( diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 13f1d39aa..55edb8a79 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -1,14 +1,22 @@ ### What this tests #### ## This test asserts the type of data passed into each method of the custom callback handler -import sys, os, time, inspect, asyncio, traceback +import asyncio +import inspect +import os +import sys +import time +import traceback +import uuid from datetime import datetime -import pytest, uuid + +import pytest from pydantic import BaseModel sys.path.insert(0, os.path.abspath("../..")) -from typing import Optional, Literal, List, Union -from litellm import completion, embedding, Cache +from typing import List, Literal, Optional, Union + import litellm +from litellm import Cache, completion, embedding from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import LiteLLMCommonStrings @@ -821,6 +829,37 @@ async def test_async_embedding_openai(): # asyncio.run(test_async_embedding_openai()) +## Test Azure + Async +def test_sync_embedding(): + try: + customHandler_success = CompletionCustomHandler() + customHandler_failure = CompletionCustomHandler() + litellm.callbacks = [customHandler_success] + response = litellm.embedding( + model="azure/azure-embedding-model", input=["good morning from litellm"] + ) + print(f"customHandler_success.errors: {customHandler_success.errors}") + print(f"customHandler_success.states: {customHandler_success.states}") + assert len(customHandler_success.errors) == 0 + assert len(customHandler_success.states) == 3 # pre, post, success + # test failure callback + litellm.callbacks = [customHandler_failure] + try: + response = litellm.embedding( + model="azure/azure-embedding-model", + input=["good morning from litellm"], + api_key="my-bad-key", + ) + except: + pass + print(f"customHandler_failure.errors: {customHandler_failure.errors}") + print(f"customHandler_failure.states: {customHandler_failure.states}") + assert len(customHandler_failure.errors) == 1 + assert len(customHandler_failure.states) == 3 # pre, post, failure + except Exception as e: + pytest.fail(f"An exception occurred: {str(e)}") + + ## Test Azure + Async @pytest.mark.asyncio async def test_async_embedding_azure():