fix(litellm_logging.py): fix condition check

Fixes https://github.com/BerriAI/litellm/issues/4633
This commit is contained in:
Krrish Dholakia 2024-07-12 09:22:19 -07:00
parent 88eb25da5c
commit f5b3cc6c02
2 changed files with 57 additions and 18 deletions

View file

@ -1105,19 +1105,19 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False
)
== False
is not True
):
global openMeterLogger
if openMeterLogger is None:
@ -1150,19 +1150,19 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False
)
== False
is not True
): # custom logger class
if self.stream and complete_streaming_response is None:
callback.log_stream_event(
@ -1190,19 +1190,19 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False
)
== False
is not True
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
@ -1634,11 +1634,11 @@ class Logging:
and self.model_call_details.get("litellm_params", {}).get(
"acompletion", False
)
== False
is not True
and self.model_call_details.get("litellm_params", {}).get(
"aembedding", False
)
== False
is not True
): # custom logger class
callback.log_failure_event(

View file

@ -1,14 +1,22 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
import asyncio
import inspect
import os
import sys
import time
import traceback
import uuid
from datetime import datetime
import pytest, uuid
import pytest
from pydantic import BaseModel
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List, Union
from litellm import completion, embedding, Cache
from typing import List, Literal, Optional, Union
import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings
@ -821,6 +829,37 @@ async def test_async_embedding_openai():
# asyncio.run(test_async_embedding_openai())
## Test Azure + Async
def test_sync_embedding():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = litellm.embedding(
model="azure/azure-embedding-model", input=["good morning from litellm"]
)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = litellm.embedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
api_key="my-bad-key",
)
except:
pass
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 1
assert len(customHandler_failure.states) == 3 # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():