fix(litellm_logging.py): fix condition check

Fixes https://github.com/BerriAI/litellm/issues/4633
This commit is contained in:
Krrish Dholakia 2024-07-12 09:22:19 -07:00
parent 88eb25da5c
commit f5b3cc6c02
2 changed files with 57 additions and 18 deletions

View file

@ -1105,19 +1105,19 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False "acompletion", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False "aembedding", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False "aimage_generation", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False "atranscription", False
) )
== False is not True
): ):
global openMeterLogger global openMeterLogger
if openMeterLogger is None: if openMeterLogger is None:
@ -1150,19 +1150,19 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False "acompletion", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False "aembedding", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False "aimage_generation", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False "atranscription", False
) )
== False is not True
): # custom logger class ): # custom logger class
if self.stream and complete_streaming_response is None: if self.stream and complete_streaming_response is None:
callback.log_stream_event( callback.log_stream_event(
@ -1190,19 +1190,19 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False "acompletion", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False "aembedding", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False "aimage_generation", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False "atranscription", False
) )
== False is not True
): # custom logger functions ): # custom logger functions
print_verbose( print_verbose(
f"success callbacks: Running Custom Callback Function" f"success callbacks: Running Custom Callback Function"
@ -1634,11 +1634,11 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False "acompletion", False
) )
== False is not True
and self.model_call_details.get("litellm_params", {}).get( and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False "aembedding", False
) )
== False is not True
): # custom logger class ): # custom logger class
callback.log_failure_event( callback.log_failure_event(

View file

@ -1,14 +1,22 @@
### What this tests #### ### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler ## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback import asyncio
import inspect
import os
import sys
import time
import traceback
import uuid
from datetime import datetime from datetime import datetime
import pytest, uuid
import pytest
from pydantic import BaseModel from pydantic import BaseModel
sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List, Union from typing import List, Literal, Optional, Union
from litellm import completion, embedding, Cache
import litellm import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings from litellm.types.utils import LiteLLMCommonStrings
@ -821,6 +829,37 @@ async def test_async_embedding_openai():
# asyncio.run(test_async_embedding_openai()) # asyncio.run(test_async_embedding_openai())
## Test Azure + Async
def test_sync_embedding():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = litellm.embedding(
model="azure/azure-embedding-model", input=["good morning from litellm"]
)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = litellm.embedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
api_key="my-bad-key",
)
except:
pass
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 1
assert len(customHandler_failure.states) == 3 # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
## Test Azure + Async ## Test Azure + Async
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_embedding_azure(): async def test_async_embedding_azure():