forked from phoenix/litellm-mirror
working test for init custom logger
This commit is contained in:
parent
24a3090ff6
commit
bd9e29b8b9
1 changed files with 16 additions and 1 deletions
|
@ -31,7 +31,7 @@ from litellm.integrations.opik.opik import OpikLogger
|
||||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||||
from litellm.integrations.argilla import ArgillaLogger
|
from litellm.integrations.argilla import ArgillaLogger
|
||||||
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
|
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
callback_class_str_to_classType = {
|
callback_class_str_to_classType = {
|
||||||
"lago": LagoLogger,
|
"lago": LagoLogger,
|
||||||
|
@ -68,6 +68,7 @@ expected_env_vars = {
|
||||||
"LOGFIRE_TOKEN": "logfire_token",
|
"LOGFIRE_TOKEN": "logfire_token",
|
||||||
"ARIZE_SPACE_KEY": "arize_space_key",
|
"ARIZE_SPACE_KEY": "arize_space_key",
|
||||||
"ARIZE_API_KEY": "arize_api_key",
|
"ARIZE_API_KEY": "arize_api_key",
|
||||||
|
"ARGILLA_API_KEY": "argilla_api_key",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,6 +109,17 @@ async def use_callback_in_llm_call(
|
||||||
elif callback == "argilla":
|
elif callback == "argilla":
|
||||||
litellm.argilla_transformation_object = {}
|
litellm.argilla_transformation_object = {}
|
||||||
|
|
||||||
|
# Mock the httpx call for Argilla dataset retrieval
|
||||||
|
if callback == "argilla":
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
mock_response = httpx.Response(
|
||||||
|
status_code=200, json={"items": [{"id": "mocked_dataset_id"}]}
|
||||||
|
)
|
||||||
|
patch.object(
|
||||||
|
litellm.module_level_client, "get", return_value=mock_response
|
||||||
|
).start()
|
||||||
|
|
||||||
if used_in == "callbacks":
|
if used_in == "callbacks":
|
||||||
litellm.callbacks = [callback]
|
litellm.callbacks = [callback]
|
||||||
elif used_in == "success_callback":
|
elif used_in == "success_callback":
|
||||||
|
@ -136,6 +148,9 @@ async def use_callback_in_llm_call(
|
||||||
assert len(litellm.failure_callback) == 1
|
assert len(litellm.failure_callback) == 1
|
||||||
assert len(litellm.callbacks) == 1
|
assert len(litellm.callbacks) == 1
|
||||||
|
|
||||||
|
if callback == "argilla":
|
||||||
|
patch.stopall()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_init_custom_logger_compatible_class_as_callback():
|
async def test_init_custom_logger_compatible_class_as_callback():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue