* feat: Added Missing Attributes For Arize & Phoenix Integration
* chore: Added noqa for PLR0915 to suppress warning
* chore: Moved Contributor Test to Correct Location
* chore: Removed Redundant Fallback

Co-authored-by: Ali Saleh <saleh.a@turing.com>
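"""
Unit tests for the Arize integration.

Covers ArizeLogger.set_arize_attributes (OpenInference span attributes set on a
mocked span) and the forwarding of dynamic Arize credentials to callback loggers.
"""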
import json
import os
import sys
from typing import Optional

# Adds the grandparent directory to sys.path to allow importing project modules
sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import litellm
import pytest
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations._types.open_inference import (
    SpanAttributes,
    MessageAttributes,
    ToolCallAttributes,
)
from litellm.types.utils import Choices, StandardCallbackDynamicParams


def test_arize_set_attributes():
    """
    Test setting attributes for Arize, including all custom LLM attributes.
    Ensures that the correct span attributes are being added during a request.
    """
    from unittest.mock import MagicMock
    from litellm.types.utils import ModelResponse

    span = MagicMock()  # Mocked tracing span to test attribute setting

    # Construct kwargs to simulate a real LLM request scenario
    kwargs = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Basic Request Content"}],
        "standard_logging_object": {
            "model_parameters": {"user": "test_user"},
            "metadata": {"key_1": "value_1", "key_2": None},
            "call_type": "completion",
        },
        "optional_params": {
            "max_tokens": "100",
            "temperature": "1",
            "top_p": "5",
            "stream": False,
            "user": "test_user",
            "tools": [
                {
                    "function": {
                        "name": "get_weather",
                        "description": "Fetches weather details.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "City name",
                                }
                            },
                            "required": ["location"],
                        },
                    }
                }
            ],
            "functions": [{"name": "get_weather"}, {"name": "get_stock_price"}],
        },
        "litellm_params": {"custom_llm_provider": "openai"},
    }

    # Simulated LLM response object
    response_obj = ModelResponse(
        usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40},
        choices=[
            Choices(message={"role": "assistant", "content": "Basic Response Content"})
        ],
        model="gpt-4o",
        id="chatcmpl-ID",
    )

    # Apply attribute setting via ArizeLogger
    ArizeLogger.set_arize_attributes(span, kwargs, response_obj)

    # Validate that the expected number of attributes were set
    assert span.set_attribute.call_count == 28
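    # 28 = one set_attribute call per attribute asserted individually below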

    # Metadata attached to the span
    span.set_attribute.assert_any_call(
        SpanAttributes.METADATA, json.dumps({"key_1": "value_1", "key_2": None})
    )

    # Basic LLM information
    span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o")
    span.set_attribute.assert_any_call("llm.request.type", "completion")
    span.set_attribute.assert_any_call(SpanAttributes.LLM_PROVIDER, "openai")

    # LLM generation parameters
    span.set_attribute.assert_any_call("llm.request.max_tokens", "100")
    span.set_attribute.assert_any_call("llm.request.temperature", "1")
    span.set_attribute.assert_any_call("llm.request.top_p", "5")

    # Streaming and user info
    span.set_attribute.assert_any_call("llm.is_streaming", "False")
    span.set_attribute.assert_any_call("llm.user", "test_user")

    # Response metadata
    span.set_attribute.assert_any_call("llm.response.id", "chatcmpl-ID")
    span.set_attribute.assert_any_call("llm.response.model", "gpt-4o")
    span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM")

    # Request message content and metadata
    span.set_attribute.assert_any_call(
        SpanAttributes.INPUT_VALUE, "Basic Request Content"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
        "user",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
        "Basic Request Content",
    )

    # Tool call definitions and function names
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_NAME}", "get_weather"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_DESCRIPTION}",
        "Fetches weather details.",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_PARAMETERS}",
        json.dumps(
            {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"],
            }
        ),
    )

    # Tool calls captured from optional_params
    span.set_attribute.assert_any_call(
        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.0.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
        "get_weather",
    )
    span.set_attribute.assert_any_call(
        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.1.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
        "get_stock_price",
    )

    # Invocation parameters
    span.set_attribute.assert_any_call(
        SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}'
    )

    # User ID
    span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user")

    # Output message content
    span.set_attribute.assert_any_call(
        SpanAttributes.OUTPUT_VALUE, "Basic Response Content"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
        "assistant",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
        "Basic Response Content",
    )

    # Token counts
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40)


class TestArizeLogger(CustomLogger):
    """
    Custom logger implementation to capture standard_callback_dynamic_params.
    Used to verify that dynamic config keys are being passed to callbacks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.standard_callback_dynamic_params: Optional[
            StandardCallbackDynamicParams
        ] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Capture dynamic params and print them for verification
        print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
        self.standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
        )


@pytest.mark.asyncio
async def test_arize_dynamic_params():
    """
    Test to ensure that dynamic Arize keys (API key and space key)
    are received inside the callback logger at runtime.
    """
    test_arize_logger = TestArizeLogger()
    litellm.callbacks = [test_arize_logger]

    # Perform a mocked async completion call to trigger logging
    await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Basic Request Content"}],
        mock_response="test",
        arize_api_key="test_api_key_dynamic",
        arize_space_key="test_space_key_dynamic",
    )

    # Allow for async propagation
    await asyncio.sleep(2)

    # Assert dynamic parameters were received in the callback
    assert test_arize_logger.standard_callback_dynamic_params is not None
    assert (
        test_arize_logger.standard_callback_dynamic_params.get("arize_api_key")
        == "test_api_key_dynamic"
    )
    assert (
        test_arize_logger.standard_callback_dynamic_params.get("arize_space_key")
        == "test_space_key_dynamic"
    )