working test for init custom logger

This commit is contained in:
Ishaan Jaff 2024-10-21 14:33:52 +05:30
parent 24a3090ff6
commit bd9e29b8b9

View file

@ -31,7 +31,7 @@ from litellm.integrations.opik.opik import OpikLogger
from litellm.integrations.opentelemetry import OpenTelemetry from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.integrations.argilla import ArgillaLogger from litellm.integrations.argilla import ArgillaLogger
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
from unittest.mock import patch
callback_class_str_to_classType = { callback_class_str_to_classType = {
"lago": LagoLogger, "lago": LagoLogger,
@ -68,6 +68,7 @@ expected_env_vars = {
"LOGFIRE_TOKEN": "logfire_token", "LOGFIRE_TOKEN": "logfire_token",
"ARIZE_SPACE_KEY": "arize_space_key", "ARIZE_SPACE_KEY": "arize_space_key",
"ARIZE_API_KEY": "arize_api_key", "ARIZE_API_KEY": "arize_api_key",
"ARGILLA_API_KEY": "argilla_api_key",
} }
@ -108,6 +109,17 @@ async def use_callback_in_llm_call(
elif callback == "argilla": elif callback == "argilla":
litellm.argilla_transformation_object = {} litellm.argilla_transformation_object = {}
# Mock the httpx call for Argilla dataset retrieval
if callback == "argilla":
import httpx
mock_response = httpx.Response(
status_code=200, json={"items": [{"id": "mocked_dataset_id"}]}
)
patch.object(
litellm.module_level_client, "get", return_value=mock_response
).start()
if used_in == "callbacks": if used_in == "callbacks":
litellm.callbacks = [callback] litellm.callbacks = [callback]
elif used_in == "success_callback": elif used_in == "success_callback":
@ -136,6 +148,9 @@ async def use_callback_in_llm_call(
assert len(litellm.failure_callback) == 1 assert len(litellm.failure_callback) == 1
assert len(litellm.callbacks) == 1 assert len(litellm.callbacks) == 1
if callback == "argilla":
patch.stopall()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_init_custom_logger_compatible_class_as_callback(): async def test_init_custom_logger_compatible_class_as_callback():