mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
111 lines
4.1 KiB
Python
111 lines
4.1 KiB
Python
import os
|
|
import sys
|
|
import time
|
|
from unittest.mock import Mock, patch
|
|
import json
|
|
import opentelemetry.exporter.otlp.proto.grpc.trace_exporter
|
|
from typing import Optional
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system-path
|
|
from litellm.integrations._types.open_inference import SpanAttributes
|
|
from litellm.integrations.arize.arize import ArizeConfig, ArizeLogger
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
from litellm.main import completion
|
|
import litellm
|
|
from litellm.types.utils import Choices, StandardCallbackDynamicParams
|
|
import pytest
|
|
import asyncio
|
|
|
|
|
|
def test_arize_set_attributes():
|
|
"""
|
|
Test setting attributes for Arize
|
|
"""
|
|
from unittest.mock import MagicMock
|
|
from litellm.types.utils import ModelResponse
|
|
|
|
span = MagicMock()
|
|
kwargs = {
|
|
"role": "user",
|
|
"content": "simple arize test",
|
|
"model": "gpt-4o",
|
|
"messages": [{"role": "user", "content": "basic arize test"}],
|
|
"standard_logging_object": {
|
|
"model_parameters": {"user": "test_user"},
|
|
"metadata": {"key": "value", "key2": None},
|
|
},
|
|
}
|
|
response_obj = ModelResponse(
|
|
usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40},
|
|
choices=[Choices(message={"role": "assistant", "content": "response content"})],
|
|
)
|
|
|
|
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
|
|
|
|
assert span.set_attribute.call_count == 14
|
|
span.set_attribute.assert_any_call(
|
|
SpanAttributes.METADATA, json.dumps({"key": "value", "key2": None})
|
|
)
|
|
span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o")
|
|
span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM")
|
|
span.set_attribute.assert_any_call(SpanAttributes.INPUT_VALUE, "basic arize test")
|
|
span.set_attribute.assert_any_call("llm.input_messages.0.message.role", "user")
|
|
span.set_attribute.assert_any_call(
|
|
"llm.input_messages.0.message.content", "basic arize test"
|
|
)
|
|
span.set_attribute.assert_any_call(
|
|
SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}'
|
|
)
|
|
span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user")
|
|
span.set_attribute.assert_any_call(SpanAttributes.OUTPUT_VALUE, "response content")
|
|
span.set_attribute.assert_any_call(
|
|
"llm.output_messages.0.message.role", "assistant"
|
|
)
|
|
span.set_attribute.assert_any_call(
|
|
"llm.output_messages.0.message.content", "response content"
|
|
)
|
|
span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100)
|
|
span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60)
|
|
span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40)
|
|
|
|
|
|
class TestArizeLogger(CustomLogger):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.standard_callback_dynamic_params: Optional[
|
|
StandardCallbackDynamicParams
|
|
] = None
|
|
|
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
|
print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
|
|
self.standard_callback_dynamic_params = kwargs.get(
|
|
"standard_callback_dynamic_params"
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_arize_dynamic_params():
|
|
"""verify arize ai dynamic params are recieved by a callback"""
|
|
test_arize_logger = TestArizeLogger()
|
|
litellm.callbacks = [test_arize_logger]
|
|
await litellm.acompletion(
|
|
model="gpt-4o",
|
|
messages=[{"role": "user", "content": "basic arize test"}],
|
|
mock_response="test",
|
|
arize_api_key="test_api_key_dynamic",
|
|
arize_space_key="test_space_key_dynamic",
|
|
)
|
|
|
|
await asyncio.sleep(2)
|
|
|
|
assert test_arize_logger.standard_callback_dynamic_params is not None
|
|
assert (
|
|
test_arize_logger.standard_callback_dynamic_params.get("arize_api_key")
|
|
== "test_api_key_dynamic"
|
|
)
|
|
assert (
|
|
test_arize_logger.standard_callback_dynamic_params.get("arize_space_key")
|
|
== "test_space_key_dynamic"
|
|
)
|