From bd9e29b8b912925a9100c026568b8935fbf7d9e5 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 21 Oct 2024 14:33:52 +0530 Subject: [PATCH] working test for init custom logger --- .../test_unit_tests_init_callbacks.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index ee9746b79..418f1f71b 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -31,7 +31,7 @@ from litellm.integrations.opik.opik import OpikLogger from litellm.integrations.opentelemetry import OpenTelemetry from litellm.integrations.argilla import ArgillaLogger from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler - +from unittest.mock import patch callback_class_str_to_classType = { "lago": LagoLogger, @@ -68,6 +68,7 @@ expected_env_vars = { "LOGFIRE_TOKEN": "logfire_token", "ARIZE_SPACE_KEY": "arize_space_key", "ARIZE_API_KEY": "arize_api_key", + "ARGILLA_API_KEY": "argilla_api_key", } @@ -108,6 +109,17 @@ async def use_callback_in_llm_call( elif callback == "argilla": litellm.argilla_transformation_object = {} + # Mock the httpx call for Argilla dataset retrieval + if callback == "argilla": + import httpx + + mock_response = httpx.Response( + status_code=200, json={"items": [{"id": "mocked_dataset_id"}]} + ) + patch.object( + litellm.module_level_client, "get", return_value=mock_response + ).start() + if used_in == "callbacks": litellm.callbacks = [callback] elif used_in == "success_callback": @@ -136,6 +148,9 @@ async def use_callback_in_llm_call( assert len(litellm.failure_callback) == 1 assert len(litellm.callbacks) == 1 + if callback == "argilla": + patch.stopall() + @pytest.mark.asyncio async def test_init_custom_logger_compatible_class_as_callback():