# litellm-mirror/tests/litellm/integrations/arize/test_arize_utils.py

import json
import os
import sys
from typing import Optional
# Adds the grandparent directory to sys.path to allow importing project modules
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import litellm
import pytest
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations._types.open_inference import (
SpanAttributes,
MessageAttributes,
ToolCallAttributes,
)
from litellm.types.utils import Choices, StandardCallbackDynamicParams
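# For reference (assumption, based on the OpenInference semantic conventions the
# Arize/Phoenix integration follows): message-level attributes are flattened into
# indexed span keys, which is the key format the f-string assertions in
# test_arize_set_attributes below build, e.g.
#   f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}"
#   -> "llm.input_messages.0.message.role"

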
def test_arize_set_attributes():
"""
Test setting attributes for Arize, including all custom LLM attributes.
Ensures that the correct span attributes are being added during a request.
"""
from unittest.mock import MagicMock
from litellm.types.utils import ModelResponse
span = MagicMock() # Mocked tracing span to test attribute setting
# Construct kwargs to simulate a real LLM request scenario
kwargs = {
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Basic Request Content"}],
"standard_logging_object": {
"model_parameters": {"user": "test_user"},
"metadata": {"key_1": "value_1", "key_2": None},
"call_type": "completion",
},
"optional_params": {
"max_tokens": "100",
"temperature": "1",
"top_p": "5",
"stream": False,
"user": "test_user",
"tools": [
{
"function": {
"name": "get_weather",
"description": "Fetches weather details.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City name",
}
},
"required": ["location"],
},
}
}
],
"functions": [{"name": "get_weather"}, {"name": "get_stock_price"}],
},
"litellm_params": {"custom_llm_provider": "openai"},
}
# Simulated LLM response object
response_obj = ModelResponse(
usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40},
choices=[
Choices(message={"role": "assistant", "content": "Basic Response Content"})
],
model="gpt-4o",
id="chatcmpl-ID",
)
# Apply attribute setting via ArizeLogger
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
# Validate that the expected number of attributes were set
assert span.set_attribute.call_count == 28
# Metadata attached to the span
span.set_attribute.assert_any_call(
SpanAttributes.METADATA, json.dumps({"key_1": "value_1", "key_2": None})
)
# Basic LLM information
span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o")
span.set_attribute.assert_any_call("llm.request.type", "completion")
span.set_attribute.assert_any_call(SpanAttributes.LLM_PROVIDER, "openai")
# LLM generation parameters
span.set_attribute.assert_any_call("llm.request.max_tokens", "100")
span.set_attribute.assert_any_call("llm.request.temperature", "1")
span.set_attribute.assert_any_call("llm.request.top_p", "5")
# Streaming and user info
span.set_attribute.assert_any_call("llm.is_streaming", "False")
span.set_attribute.assert_any_call("llm.user", "test_user")
# Response metadata
span.set_attribute.assert_any_call("llm.response.id", "chatcmpl-ID")
span.set_attribute.assert_any_call("llm.response.model", "gpt-4o")
span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM")
# Request message content and metadata
span.set_attribute.assert_any_call(
SpanAttributes.INPUT_VALUE, "Basic Request Content"
)
span.set_attribute.assert_any_call(
f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
"user",
)
span.set_attribute.assert_any_call(
f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
"Basic Request Content",
)
# Tool call definitions and function names
span.set_attribute.assert_any_call(
f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_NAME}", "get_weather"
)
span.set_attribute.assert_any_call(
f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_DESCRIPTION}",
"Fetches weather details.",
)
span.set_attribute.assert_any_call(
f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_PARAMETERS}",
json.dumps(
{
"type": "object",
"properties": {
"location": {"type": "string", "description": "City name"}
},
"required": ["location"],
}
),
)
# Tool calls captured from optional_params
span.set_attribute.assert_any_call(
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.0.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
"get_weather",
)
span.set_attribute.assert_any_call(
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.1.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
"get_stock_price",
)
# Invocation parameters
span.set_attribute.assert_any_call(
SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}'
)
# User ID
span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user")
# Output message content
span.set_attribute.assert_any_call(
SpanAttributes.OUTPUT_VALUE, "Basic Response Content"
)
span.set_attribute.assert_any_call(
f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
"assistant",
)
span.set_attribute.assert_any_call(
f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
"Basic Response Content",
)
# Token counts
span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100)
span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60)
span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40)
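# Illustrative sketch (not exercised by this test suite): how the Arize integration
# is typically enabled in application code. The "arize" callback string and the
# ARIZE_SPACE_KEY / ARIZE_API_KEY environment variable names are assumptions taken
# from the documented litellm setup, not something asserted by these tests.
def _example_enable_arize_logging():
    os.environ.setdefault("ARIZE_SPACE_KEY", "my-space-key")  # assumed env var name
    os.environ.setdefault("ARIZE_API_KEY", "my-api-key")  # assumed env var name
    litellm.callbacks = ["arize"]
    # mock_response avoids a real provider call while still exercising the callbacks
    return litellm.completion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Basic Request Content"}],
        mock_response="test",
    )

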
class TestArizeLogger(CustomLogger):
"""
Custom logger implementation to capture standard_callback_dynamic_params.
Used to verify that dynamic config keys are being passed to callbacks.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = None
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
# Capture dynamic params and print them for verification
print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
self.standard_callback_dynamic_params = kwargs.get(
"standard_callback_dynamic_params"
)
@pytest.mark.asyncio
async def test_arize_dynamic_params():
"""
Test to ensure that dynamic Arize keys (API key and space key)
are received inside the callback logger at runtime.
"""
test_arize_logger = TestArizeLogger()
litellm.callbacks = [test_arize_logger]
# Perform a mocked async completion call to trigger logging
await litellm.acompletion(
model="gpt-4o",
messages=[{"role": "user", "content": "Basic Request Content"}],
mock_response="test",
arize_api_key="test_api_key_dynamic",
arize_space_key="test_space_key_dynamic",
)
# Allow for async propagation
await asyncio.sleep(2)
# Assert dynamic parameters were received in the callback
assert test_arize_logger.standard_callback_dynamic_params is not None
assert (
test_arize_logger.standard_callback_dynamic_params.get("arize_api_key")
== "test_api_key_dynamic"
)
assert (
test_arize_logger.standard_callback_dynamic_params.get("arize_space_key")
== "test_space_key_dynamic"
)
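

# Hedged sketch of how a callback might consume the dynamic params captured above:
# prefer the per-request keys from standard_callback_dynamic_params and fall back
# to environment variables. The helper and the fallback env var names are
# illustrative assumptions, not part of the integration under test.
def _example_resolve_arize_keys(
    dynamic_params: Optional[StandardCallbackDynamicParams],
):
    params = dynamic_params or {}
    api_key = params.get("arize_api_key") or os.environ.get("ARIZE_API_KEY")
    space_key = params.get("arize_space_key") or os.environ.get("ARIZE_SPACE_KEY")
    return api_key, space_key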