diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 062f78a76..579fe6583 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -176,7 +176,8 @@ class CompletionCustomHandler( ) or isinstance(kwargs["input"], (dict, str)) assert isinstance(kwargs["api_key"], (str, type(None))) assert isinstance( - kwargs["original_response"], (str, litellm.CustomStreamWrapper) + kwargs["original_response"], + (str, litellm.CustomStreamWrapper, BaseModel), ) assert isinstance(kwargs["additional_args"], (dict, type(None))) assert isinstance(kwargs["log_event_type"], str)