(testing) add test coverage for init custom logger class (#6341)

* working test for init custom logger

* add test coverage for custom_logger_compatible_class_as_callback
This commit is contained in:
Ishaan Jaff 2024-10-21 15:56:32 +05:30 committed by GitHub
parent bd9e29b8b9
commit d1f457d17a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -108,6 +108,20 @@ async def use_callback_in_llm_call(
return return
elif callback == "argilla": elif callback == "argilla":
litellm.argilla_transformation_object = {} litellm.argilla_transformation_object = {}
elif callback == "openmeter":
# it's currently handled in jank way, TODO: fix openmete and then actually run it's test
return
# Mock the httpx call for Argilla dataset retrieval
if callback == "argilla":
import httpx
mock_response = httpx.Response(
status_code=200, json={"items": [{"id": "mocked_dataset_id"}]}
)
patch.object(
litellm.module_level_client, "get", return_value=mock_response
).start()
# Mock the httpx call for Argilla dataset retrieval # Mock the httpx call for Argilla dataset retrieval
if callback == "argilla": if callback == "argilla":
@ -125,7 +139,7 @@ async def use_callback_in_llm_call(
elif used_in == "success_callback": elif used_in == "success_callback":
litellm.success_callback = [callback] litellm.success_callback = [callback]
for _ in range(1): for _ in range(5):
await litellm.acompletion( await litellm.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}], messages=[{"role": "user", "content": "hi"}],
@ -137,6 +151,7 @@ async def use_callback_in_llm_call(
expected_class = callback_class_str_to_classType[callback] expected_class = callback_class_str_to_classType[callback]
if used_in == "callbacks":
assert isinstance(litellm._async_success_callback[0], expected_class) assert isinstance(litellm._async_success_callback[0], expected_class)
assert isinstance(litellm._async_failure_callback[0], expected_class) assert isinstance(litellm._async_failure_callback[0], expected_class)
assert isinstance(litellm.success_callback[0], expected_class) assert isinstance(litellm.success_callback[0], expected_class)
@ -147,6 +162,20 @@ async def use_callback_in_llm_call(
assert len(litellm.success_callback) == 1 assert len(litellm.success_callback) == 1
assert len(litellm.failure_callback) == 1 assert len(litellm.failure_callback) == 1
assert len(litellm.callbacks) == 1 assert len(litellm.callbacks) == 1
elif used_in == "success_callback":
print(f"litellm.success_callback: {litellm.success_callback}")
print(f"litellm._async_success_callback: {litellm._async_success_callback}")
assert isinstance(litellm.success_callback[1], expected_class)
assert len(litellm.success_callback) == 2 # ["lago", LagoLogger]
assert isinstance(litellm._async_success_callback[0], expected_class)
assert len(litellm._async_success_callback) == 1
# TODO also assert that it's not set for failure_callback
# As of Oct 21 2024, it's currently set
# 1st hoping to add test coverage for just setting in success_callback/_async_success_callback
if callback == "argilla":
patch.stopall()
if callback == "argilla": if callback == "argilla":
patch.stopall() patch.stopall()
@ -156,10 +185,18 @@ async def use_callback_in_llm_call(
async def test_init_custom_logger_compatible_class_as_callback(): async def test_init_custom_logger_compatible_class_as_callback():
init_env_vars() init_env_vars()
# used like litellm.callbacks = ["prometheus"]
for callback in litellm._known_custom_logger_compatible_callbacks: for callback in litellm._known_custom_logger_compatible_callbacks:
print(f"Testing callback: {callback}") print(f"Testing callback: {callback}")
reset_all_callbacks() reset_all_callbacks()
await use_callback_in_llm_call(callback, used_in="callbacks") await use_callback_in_llm_call(callback, used_in="callbacks")
# used like this litellm.success_callback = ["prometheus"]
for callback in litellm._known_custom_logger_compatible_callbacks:
print(f"Testing callback: {callback}")
reset_all_callbacks()
await use_callback_in_llm_call(callback, used_in="success_callback")
reset_env_vars() reset_env_vars()