fix(utils.py): move adding custom logger callback to success event in… (#7905)

* fix(utils.py): move adding custom logger callback to success event into separate function + don't add success callback to failure event

if user is explicitly choosing 'success' callback, don't log failure as well

* test(test_utils.py): add unit test to ensure custom logger callback only adds callback to specific event

* fix(utils.py): remove string from list of callbacks once corresponding callback class is added

prevents floating values - simplifies testing

* fix(utils.py): fix linting error

* test: cleanup args before test

* test: fix test

* test: update test

* test: fix test
This commit is contained in:
Krish Dholakia 2025-01-22 21:49:09 -08:00 committed by GitHub
parent cefbada875
commit 4911cd80a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 98 additions and 49 deletions

View file

@ -50,10 +50,10 @@ if set_verbose == True:
_turn_on_debug()
###############################################
### Callbacks /Logging / Success / Failure Handlers #####
input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
input_callback: List[Union[str, Callable, CustomLogger]] = []
success_callback: List[Union[str, Callable, CustomLogger]] = []
failure_callback: List[Union[str, Callable, CustomLogger]] = []
service_callback: List[Union[str, Callable, CustomLogger]] = []
_custom_logger_compatible_callbacks_literal = Literal[
"lago",
"openmeter",
@ -90,13 +90,13 @@ langsmith_batch_size: Optional[int] = None
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Callable] = (
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable]] = (
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = (
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []

View file

@ -10,3 +10,6 @@ model_list:
api_base: http://0.0.0.0:8090
timeout: 2
num_retries: 0
litellm_settings:
success_callback: ["langfuse"]

View file

@ -318,9 +318,61 @@ def custom_llm_setup():
litellm._custom_providers.append(custom_llm["provider"])
def _add_custom_logger_callback_to_specific_event(
callback: str, logging_event: Literal["success", "failure"]
) -> None:
"""
Add a custom logger callback to the specific event
"""
from litellm import _custom_logger_compatible_callbacks_literal
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
if callback not in litellm._known_custom_logger_compatible_callbacks:
verbose_logger.debug(
f"Callback {callback} is not a valid custom logger compatible callback. Known list - {litellm._known_custom_logger_compatible_callbacks}"
)
return
callback_class = _init_custom_logger_compatible_class(
cast(_custom_logger_compatible_callbacks_literal, callback),
internal_usage_cache=None,
llm_router=None,
)
# don't double add a callback
if callback_class is not None and not any(
isinstance(cb, type(callback_class)) for cb in litellm.callbacks # type: ignore
):
if logging_event == "success":
litellm.success_callback.append(callback_class)
litellm._async_success_callback.append(callback_class)
if callback in litellm.success_callback:
litellm.success_callback.remove(
callback
) # remove the string from the callback list
if callback in litellm._async_success_callback:
litellm._async_success_callback.remove(
callback
) # remove the string from the callback list
elif logging_event == "failure":
litellm.failure_callback.append(callback_class)
litellm._async_failure_callback.append(callback_class)
if callback in litellm.failure_callback:
litellm.failure_callback.remove(
callback
) # remove the string from the callback list
if callback in litellm._async_failure_callback:
litellm._async_failure_callback.remove(
callback
) # remove the string from the callback list
def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
### NOTICES ###
from litellm import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import set_callbacks
@ -401,27 +453,11 @@ def function_setup( # noqa: PLR0915
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
litellm._async_success_callback.append(callback)
removed_async_items.append(index)
elif callback in litellm._known_custom_logger_compatible_callbacks:
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
callback_class = _init_custom_logger_compatible_class(
callback, # type: ignore
internal_usage_cache=None,
llm_router=None, # type: ignore
)
# don't double add a callback
if callback_class is not None and not any(
isinstance(cb, type(callback_class)) for cb in litellm.callbacks
elif (
callback in litellm._known_custom_logger_compatible_callbacks
and isinstance(callback, str)
):
litellm.callbacks.append(callback_class) # type: ignore
litellm.input_callback.append(callback_class) # type: ignore
litellm.success_callback.append(callback_class) # type: ignore
litellm.failure_callback.append(callback_class) # type: ignore
litellm._async_success_callback.append(callback_class) # type: ignore
litellm._async_failure_callback.append(callback_class) # type: ignore
_add_custom_logger_callback_to_specific_event(callback, "success")
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):

View file

@ -1496,24 +1496,34 @@ def test_get_num_retries(num_retries):
)
@pytest.mark.parametrize("filter_invalid_headers", [True, False])
@pytest.mark.parametrize(
"custom_llm_provider, expected_result",
[("anthropic", {"anthropic-beta": "123"}), ("bedrock", {}), ("vertex_ai", {})],
)
def test_get_clean_extra_headers(
filter_invalid_headers, custom_llm_provider, expected_result, monkeypatch
):
from litellm.utils import get_clean_extra_headers
def test_add_custom_logger_callback_to_specific_event(monkeypatch):
from litellm.utils import _add_custom_logger_callback_to_specific_event
monkeypatch.setattr(litellm, "filter_invalid_headers", filter_invalid_headers)
monkeypatch.setattr(litellm, "success_callback", [])
monkeypatch.setattr(litellm, "failure_callback", [])
if filter_invalid_headers:
assert (
get_clean_extra_headers({"anthropic-beta": "123"}, custom_llm_provider)
== expected_result
_add_custom_logger_callback_to_specific_event("langfuse", "success")
assert len(litellm.success_callback) == 1
assert len(litellm.failure_callback) == 0
def test_add_custom_logger_callback_to_specific_event_e2e(monkeypatch):
monkeypatch.setattr(litellm, "success_callback", [])
monkeypatch.setattr(litellm, "failure_callback", [])
monkeypatch.setattr(litellm, "callbacks", [])
litellm.success_callback = ["humanloop"]
curr_len_success_callback = len(litellm.success_callback)
curr_len_failure_callback = len(litellm.failure_callback)
litellm.completion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response="Testing langfuse",
)
else:
assert get_clean_extra_headers(
{"anthropic-beta": "123"}, custom_llm_provider
) == {"anthropic-beta": "123"}
assert len(litellm.success_callback) == curr_len_success_callback
assert len(litellm.failure_callback) == curr_len_failure_callback

View file

@ -193,8 +193,8 @@ async def use_callback_in_llm_call(
elif used_in == "success_callback":
print(f"litellm.success_callback: {litellm.success_callback}")
print(f"litellm._async_success_callback: {litellm._async_success_callback}")
assert isinstance(litellm.success_callback[1], expected_class)
assert len(litellm.success_callback) == 2 # ["lago", LagoLogger]
assert isinstance(litellm.success_callback[0], expected_class)
assert len(litellm.success_callback) == 1 # ["lago", LagoLogger]
assert isinstance(litellm._async_success_callback[0], expected_class)
assert len(litellm._async_success_callback) == 1