fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set

Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically
This commit is contained in:
Krrish Dholakia 2024-11-23 02:00:45 +05:30
parent dfb34dfe92
commit 541326731f
2 changed files with 96 additions and 58 deletions

View file

@ -934,19 +934,10 @@ class Logging:
status="success",
)
)
if self.dynamic_success_callbacks is not None and isinstance(
self.dynamic_success_callbacks, list
):
callbacks = self.dynamic_success_callbacks
## keep the internal functions ##
for callback in litellm.success_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm.success_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks,
global_callbacks=litellm.success_callback,
)
## REDACT MESSAGES ##
result = redact_message_input_output_from_logging(
@ -1469,21 +1460,10 @@ class Logging:
status="success",
)
)
if self.dynamic_async_success_callbacks is not None and isinstance(
self.dynamic_async_success_callbacks, list
):
callbacks = self.dynamic_async_success_callbacks
## keep the internal functions ##
for callback in litellm._async_success_callback:
callback_name = ""
if isinstance(callback, CustomLogger):
callback_name = callback.__class__.__name__
if callable(callback):
callback_name = callback.__name__
if "_PROXY_" in callback_name:
callbacks.append(callback)
else:
callbacks = litellm._async_success_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
global_callbacks=litellm._async_success_callback,
)
result = redact_message_input_output_from_logging(
model_call_details=(
@ -1750,21 +1730,10 @@ class Logging:
start_time=start_time,
end_time=end_time,
)
callbacks = [] # init this to empty incase it's not created
if self.dynamic_failure_callbacks is not None and isinstance(
self.dynamic_failure_callbacks, list
):
callbacks = self.dynamic_failure_callbacks
## keep the internal functions ##
for callback in litellm.failure_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm.failure_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_failure_callbacks,
global_callbacks=litellm.failure_callback,
)
result = None # result sent to all loggers, init this to None incase it's not created
@ -1947,21 +1916,10 @@ class Logging:
end_time=end_time,
)
callbacks = [] # init this to empty incase it's not created
if self.dynamic_async_failure_callbacks is not None and isinstance(
self.dynamic_async_failure_callbacks, list
):
callbacks = self.dynamic_async_failure_callbacks
## keep the internal functions ##
for callback in litellm._async_failure_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm._async_failure_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_failure_callbacks,
global_callbacks=litellm._async_failure_callback,
)
result = None # result sent to all loggers, init this to None incase it's not created
for callback in callbacks:
@ -2953,3 +2911,11 @@ def modify_integration(integration_name, integration_params):
if integration_name == "supabase":
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]
def get_combined_callback_list(
dynamic_success_callbacks: Optional[List], global_callbacks: List
) -> List:
if dynamic_success_callbacks is None:
return global_callbacks
return list(set(dynamic_success_callbacks + global_callbacks))

View file

@ -216,3 +216,75 @@ async def test_init_custom_logger_compatible_class_as_callback():
await use_callback_in_llm_call(callback, used_in="success_callback")
reset_env_vars()
def test_dynamic_logging_global_callback():
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import ModelResponse, Choices, Message, Usage
cl = CustomLogger()
litellm_logging = LiteLLMLoggingObj(
model="claude-3-opus-20240229",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="completion",
start_time=datetime.now(),
litellm_call_id="123",
function_id="456",
kwargs={
"langfuse_public_key": "my-mock-public-key",
"langfuse_secret_key": "my-mock-secret-key",
},
dynamic_success_callbacks=["langfuse"],
)
with patch.object(cl, "log_success_event") as mock_log_success_event:
cl.log_success_event = mock_log_success_event
litellm.success_callback = [cl]
try:
litellm_logging.success_handler(
result=ModelResponse(
id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
created=1732306261,
model="claude-3-opus-20240229",
object="chat.completion",
system_fingerprint=None,
choices=[
Choices(
finish_reason="stop",
index=0,
message=Message(
content="hello",
role="assistant",
tool_calls=None,
function_call=None,
),
)
],
usage=Usage(
completion_tokens=20,
prompt_tokens=10,
total_tokens=30,
completion_tokens_details=None,
prompt_tokens_details=None,
),
),
start_time=datetime.now(),
end_time=datetime.now(),
cache_hit=False,
)
except Exception as e:
print(f"Error: {e}")
mock_log_success_event.assert_called_once()
def test_get_combined_callback_list():
from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list
assert get_combined_callback_list(
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
) == ["langfuse", "lago"]