diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 532fac5d79..556628d828 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -152,11 +152,11 @@ class CompletionCustomHandler( ## RESPONSE OBJECT assert isinstance( response_obj, - Union[ + ( litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse, - ], + ), ) ## KWARGS assert isinstance(kwargs["model"], str)