Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 10:14:26 +00:00
* feat: Added Missing Attributes For Arize & Phoenix Integration
* chore: Added noqa for PLR0915 to suppress warning
* chore: Moved Contributor Test to Correct Location
* chore: Removed Redundant Fallback

Co-authored-by: Ali Saleh <saleh.a@turing.com>
This commit is contained in:
parent
5f98d4d7de
commit
96e31d205c
4 changed files with 540 additions and 156 deletions
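At a glance: after this change, LLM calls traced to Arize or Phoenix carry additional flattened OpenInference span attributes. A hedged sketch of the kind of key/value pairs the new code emits (the keys come from the diff below; `span` and the values here are purely illustrative):

# Illustrative only: span is any OTEL span; the keys are constants added in the diff below.
span.set_attribute("llm.provider", "openai")                           # SpanAttributes.LLM_PROVIDER
span.set_attribute("llm.tools.0.tool.name", "get_weather")             # LLM_TOOLS + TOOL_NAME, indexed per tool
span.set_attribute("llm.token_count.prompt_details.cache_read", 128)   # cache-read prompt tokens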
litellm/integrations/_types/open_inference.py

@@ -45,6 +45,14 @@ class SpanAttributes:
     """
     The name of the model being used.
     """
+    LLM_PROVIDER = "llm.provider"
+    """
+    The provider of the model, such as OpenAI, Azure, Google, etc.
+    """
+    LLM_SYSTEM = "llm.system"
+    """
+    The AI product as identified by the client or server
+    """
     LLM_PROMPTS = "llm.prompts"
     """
     Prompts provided to a completions API.
@@ -65,15 +73,40 @@ class SpanAttributes:
     """
     Number of tokens in the prompt.
     """
+    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = "llm.token_count.prompt_details.cache_write"
+    """
+    Number of tokens in the prompt that were written to cache.
+    """
+    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = "llm.token_count.prompt_details.cache_read"
+    """
+    Number of tokens in the prompt that were read from cache.
+    """
+    LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = "llm.token_count.prompt_details.audio"
+    """
+    The number of audio input tokens presented in the prompt
+    """
     LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
     """
     Number of tokens in the completion.
     """
+    LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = "llm.token_count.completion_details.reasoning"
+    """
+    Number of tokens used for reasoning steps in the completion.
+    """
+    LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = "llm.token_count.completion_details.audio"
+    """
+    The number of audio input tokens generated by the model
+    """
     LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"
     """
     Total number of tokens, including both prompt and completion.
     """
+
+    LLM_TOOLS = "llm.tools"
+    """
+    List of tools that are advertised to the LLM to be able to call
+    """
 
     TOOL_NAME = "tool.name"
     """
     Name of the tool being used.
@@ -112,6 +145,19 @@ class SpanAttributes:
     The id of the user
     """
 
+    PROMPT_VENDOR = "prompt.vendor"
+    """
+    The vendor or origin of the prompt, e.g. a prompt library, a specialized service, etc.
+    """
+    PROMPT_ID = "prompt.id"
+    """
+    A vendor-specific id used to locate the prompt.
+    """
+    PROMPT_URL = "prompt.url"
+    """
+    A vendor-specific url used to locate the prompt.
+    """
+
 
 class MessageAttributes:
     """
@@ -151,6 +197,10 @@ class MessageAttributes:
     The JSON string representing the arguments passed to the function
     during a function call.
     """
+    MESSAGE_TOOL_CALL_ID = "message.tool_call_id"
+    """
+    The id of the tool call.
+    """
 
 
 class MessageContentAttributes:
@@ -186,6 +236,25 @@ class ImageAttributes:
     """
 
 
+class AudioAttributes:
+    """
+    Attributes for audio
+    """
+
+    AUDIO_URL = "audio.url"
+    """
+    The url to an audio file
+    """
+    AUDIO_MIME_TYPE = "audio.mime_type"
+    """
+    The mime type of the audio file
+    """
+    AUDIO_TRANSCRIPT = "audio.transcript"
+    """
+    The transcript of the audio file
+    """
+
+
 class DocumentAttributes:
     """
     Attributes for a document.
@@ -257,6 +326,10 @@ class ToolCallAttributes:
     Attributes for a tool call
     """
 
+    TOOL_CALL_ID = "tool_call.id"
+    """
+    The id of the tool call.
+    """
     TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
     """
     The name of function that is being called during a tool call.
@@ -268,6 +341,18 @@ class ToolCallAttributes:
     """
 
 
+class ToolAttributes:
+    """
+    Attributes for a tools
+    """
+
+    TOOL_JSON_SCHEMA = "tool.json_schema"
+    """
+    The json schema of a tool input, It is RECOMMENDED that this be in the
+    OpenAI tool calling format: https://platform.openai.com/docs/assistants/tools
+    """
+
+
 class OpenInferenceSpanKindValues(Enum):
     TOOL = "TOOL"
     CHAIN = "CHAIN"
@@ -284,3 +369,21 @@ class OpenInferenceSpanKindValues(Enum):
 class OpenInferenceMimeTypeValues(Enum):
     TEXT = "text/plain"
     JSON = "application/json"
+
+
+class OpenInferenceLLMSystemValues(Enum):
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+    COHERE = "cohere"
+    MISTRALAI = "mistralai"
+    VERTEXAI = "vertexai"
+
+
+class OpenInferenceLLMProviderValues(Enum):
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+    COHERE = "cohere"
+    MISTRALAI = "mistralai"
+    GOOGLE = "google"
+    AZURE = "azure"
+    AWS = "aws"
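The message, tool, and tool-call attributes above are designed to be composed into indexed keys, one per message or tool. A minimal sketch of that composition, assuming the constants above (the index and the resulting string are illustrative):

# Sketch: composing an indexed input-message key, as the Arize utils below do.
from litellm.integrations._types.open_inference import MessageAttributes, SpanAttributes

prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.0"
role_key = f"{prefix}.{MessageAttributes.MESSAGE_ROLE}"  # "llm.input_messages.0.message.role"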
litellm/integrations/arize/_utils.py

@@ -1,3 +1,4 @@
+import json
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from litellm._logging import verbose_logger
@@ -12,36 +13,141 @@ else:
     Span = Any
 
 
-def set_attributes(span: Span, kwargs, response_obj):
+def cast_as_primitive_value_type(value) -> Union[str, bool, int, float]:
+    """
+    Converts a value to an OTEL-supported primitive for Arize/Phoenix observability.
+    """
+    if value is None:
+        return ""
+    if isinstance(value, (str, bool, int, float)):
+        return value
+    try:
+        return str(value)
+    except Exception:
+        return ""
+
+
+def safe_set_attribute(span: Span, key: str, value: Any):
+    """
+    Sets a span attribute safely with OTEL-compliant primitive typing for Arize/Phoenix.
+    """
+    primitive_value = cast_as_primitive_value_type(value)
+    span.set_attribute(key, primitive_value)
+
+
+def set_attributes(span: Span, kwargs, response_obj):  # noqa: PLR0915
+    """
+    Populates span with OpenInference-compliant LLM attributes for Arize and Phoenix tracing.
+    """
     from litellm.integrations._types.open_inference import (
         MessageAttributes,
         OpenInferenceSpanKindValues,
         SpanAttributes,
+        ToolCallAttributes,
     )
 
     try:
         optional_params = kwargs.get("optional_params", {})
+        litellm_params = kwargs.get("litellm_params", {})
         standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object"
         )
+        if standard_logging_payload is None:
+            raise ValueError("standard_logging_object not found in kwargs")
 
         #############################################
         ############ LLM CALL METADATA ##############
         #############################################
 
-        if standard_logging_payload and (
-            metadata := standard_logging_payload["metadata"]
-        ):
-            span.set_attribute(SpanAttributes.METADATA, safe_dumps(metadata))
+        # Set custom metadata for observability and trace enrichment.
+        metadata = (
+            standard_logging_payload.get("metadata")
+            if standard_logging_payload
+            else None
+        )
+        if metadata is not None:
+            safe_set_attribute(span, SpanAttributes.METADATA, safe_dumps(metadata))
 
         #############################################
         ########## LLM Request Attributes ###########
         #############################################
 
-        # The name of the LLM a request is being made to
+        # The name of the LLM a request is being made to.
         if kwargs.get("model"):
-            span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model"))
+            safe_set_attribute(
+                span,
+                SpanAttributes.LLM_MODEL_NAME,
+                kwargs.get("model"),
+            )
 
-        span.set_attribute(
+        # The LLM request type.
+        safe_set_attribute(
+            span,
             "llm.request.type",
             standard_logging_payload["call_type"],
         )
 
+        # The Generative AI Provider: Azure, OpenAI, etc.
+        safe_set_attribute(
+            span,
+            SpanAttributes.LLM_PROVIDER,
+            litellm_params.get("custom_llm_provider", "Unknown"),
+        )
+
+        # The maximum number of tokens the LLM generates for a request.
+        if optional_params.get("max_tokens"):
+            safe_set_attribute(
+                span,
+                "llm.request.max_tokens",
+                optional_params.get("max_tokens"),
+            )
+
+        # The temperature setting for the LLM request.
+        if optional_params.get("temperature"):
+            safe_set_attribute(
+                span,
+                "llm.request.temperature",
+                optional_params.get("temperature"),
+            )
+
+        # The top_p sampling setting for the LLM request.
+        if optional_params.get("top_p"):
+            safe_set_attribute(
+                span,
+                "llm.request.top_p",
+                optional_params.get("top_p"),
+            )
+
+        # Indicates whether response is streamed.
+        safe_set_attribute(
+            span,
+            "llm.is_streaming",
+            str(optional_params.get("stream", False)),
+        )
+
+        # Logs the user ID if present.
+        if optional_params.get("user"):
+            safe_set_attribute(
+                span,
+                "llm.user",
+                optional_params.get("user"),
+            )
+
+        # The unique identifier for the completion.
+        if response_obj and response_obj.get("id"):
+            safe_set_attribute(span, "llm.response.id", response_obj.get("id"))
+
+        # The model used to generate the response.
+        if response_obj and response_obj.get("model"):
+            safe_set_attribute(
+                span,
+                "llm.response.model",
+                response_obj.get("model"),
+            )
+
+        # Required by OpenInference to mark span as LLM kind.
+        safe_set_attribute(
+            span,
             SpanAttributes.OPENINFERENCE_SPAN_KIND,
             OpenInferenceSpanKindValues.LLM.value,
         )
@@ -50,77 +156,132 @@ def set_attributes(span: Span, kwargs, response_obj):
         # for /chat/completions
         # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
         if messages:
-            span.set_attribute(
+            last_message = messages[-1]
+            safe_set_attribute(
+                span,
                 SpanAttributes.INPUT_VALUE,
-                messages[-1].get("content", ""),  # get the last message for input
+                last_message.get("content", ""),
             )
 
-            # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page
+            # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page.
             for idx, msg in enumerate(messages):
-                # Set the role per message
-                span.set_attribute(
-                    f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}",
-                    msg["role"],
+                prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}"
+                # Set the role per message.
+                safe_set_attribute(
+                    span, f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", msg.get("role")
                 )
-                # Set the content per message
-                span.set_attribute(
-                    f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}",
+                # Set the content per message.
+                safe_set_attribute(
+                    span,
+                    f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
                     msg.get("content", ""),
                 )
 
-        if standard_logging_payload and (
-            model_params := standard_logging_payload["model_parameters"]
-        ):
+        # Capture tools (function definitions) used in the LLM call.
+        tools = optional_params.get("tools")
+        if tools:
+            for idx, tool in enumerate(tools):
+                function = tool.get("function")
+                if not function:
+                    continue
+                prefix = f"{SpanAttributes.LLM_TOOLS}.{idx}"
+                safe_set_attribute(
+                    span, f"{prefix}.{SpanAttributes.TOOL_NAME}", function.get("name")
+                )
+                safe_set_attribute(
+                    span,
+                    f"{prefix}.{SpanAttributes.TOOL_DESCRIPTION}",
+                    function.get("description"),
+                )
+                safe_set_attribute(
+                    span,
+                    f"{prefix}.{SpanAttributes.TOOL_PARAMETERS}",
+                    json.dumps(function.get("parameters")),
+                )
+
+        # Capture tool calls made during function-calling LLM flows.
+        functions = optional_params.get("functions")
+        if functions:
+            for idx, function in enumerate(functions):
+                prefix = f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{idx}"
+                safe_set_attribute(
+                    span,
+                    f"{prefix}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
+                    function.get("name"),
+                )
+
+        # Capture invocation parameters and user ID if available.
+        model_params = (
+            standard_logging_payload.get("model_parameters")
+            if standard_logging_payload
+            else None
+        )
+        if model_params:
             # The Generative AI Provider: Azure, OpenAI, etc.
-            span.set_attribute(
-                SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params)
+            safe_set_attribute(
+                span,
+                SpanAttributes.LLM_INVOCATION_PARAMETERS,
+                safe_dumps(model_params),
             )
 
-            if model_params.get("user"):
-                user_id = model_params.get("user")
-                if user_id is not None:
-                    span.set_attribute(SpanAttributes.USER_ID, user_id)
+            user_id = model_params.get("user")
+            if user_id is not None:
+                safe_set_attribute(span, SpanAttributes.USER_ID, user_id)
 
         #############################################
         ########## LLM Response Attributes ##########
         # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
         #############################################
-        if hasattr(response_obj, "get"):
-            for choice in response_obj.get("choices", []):
-                response_message = choice.get("message", {})
-                span.set_attribute(
-                    SpanAttributes.OUTPUT_VALUE, response_message.get("content", "")
-                )
-
-                # This shows up under `output_messages` tab on the span page
-                # This code assumes a single response
-                span.set_attribute(
-                    f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
-                    response_message.get("role"),
-                )
-                span.set_attribute(
-                    f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
-                    response_message.get("content", ""),
-                )
-
-            usage = response_obj.get("usage")
+        # Captures response tokens, message, and content.
+        if hasattr(response_obj, "get"):
+            for idx, choice in enumerate(response_obj.get("choices", [])):
+                response_message = choice.get("message", {})
+                safe_set_attribute(
+                    span,
+                    SpanAttributes.OUTPUT_VALUE,
+                    response_message.get("content", ""),
+                )
+
+                # This shows up under `output_messages` tab on the span page.
+                prefix = f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}"
+                safe_set_attribute(
+                    span,
+                    f"{prefix}.{MessageAttributes.MESSAGE_ROLE}",
+                    response_message.get("role"),
+                )
+                safe_set_attribute(
+                    span,
+                    f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}",
+                    response_message.get("content", ""),
+                )
+
+        # Token usage info.
+        usage = response_obj and response_obj.get("usage")
         if usage:
-            span.set_attribute(
+            safe_set_attribute(
+                span,
                 SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
                 usage.get("total_tokens"),
             )
 
             # The number of tokens used in the LLM response (completion).
-            span.set_attribute(
+            safe_set_attribute(
+                span,
                 SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
                 usage.get("completion_tokens"),
             )
 
             # The number of tokens used in the LLM prompt.
-            span.set_attribute(
+            safe_set_attribute(
+                span,
                 SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
                 usage.get("prompt_tokens"),
             )
-        pass
 
     except Exception as e:
-        verbose_logger.error(f"Error setting arize attributes: {e}")
+        verbose_logger.error(
+            f"[Arize/Phoenix] Failed to set OpenInference span attributes: {e}"
+        )
+        if hasattr(span, "record_exception"):
+            span.record_exception(e)
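A quick illustrative check of what the new casting helper guarantees, assuming the definitions above (not part of the commit):

# Primitives pass through unchanged; None degrades to ""; anything else falls back to str().
assert cast_as_primitive_value_type(None) == ""
assert cast_as_primitive_value_type(0.7) == 0.7
assert cast_as_primitive_value_type({"user": "test"}) == "{'user': 'test'}"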
tests/litellm/integrations/arize/test_arize_utils.py (new file, +231)
@@ -0,0 +1,231 @@
import json
import os
import sys
from typing import Optional

# Adds the grandparent directory to sys.path to allow importing project modules
sys.path.insert(0, os.path.abspath("../.."))

import asyncio
import litellm
import pytest
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations._types.open_inference import (
    SpanAttributes,
    MessageAttributes,
    ToolCallAttributes,
)
from litellm.types.utils import Choices, StandardCallbackDynamicParams


def test_arize_set_attributes():
    """
    Test setting attributes for Arize, including all custom LLM attributes.
    Ensures that the correct span attributes are being added during a request.
    """
    from unittest.mock import MagicMock
    from litellm.types.utils import ModelResponse

    span = MagicMock()  # Mocked tracing span to test attribute setting

    # Construct kwargs to simulate a real LLM request scenario
    kwargs = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Basic Request Content"}],
        "standard_logging_object": {
            "model_parameters": {"user": "test_user"},
            "metadata": {"key_1": "value_1", "key_2": None},
            "call_type": "completion",
        },
        "optional_params": {
            "max_tokens": "100",
            "temperature": "1",
            "top_p": "5",
            "stream": False,
            "user": "test_user",
            "tools": [
                {
                    "function": {
                        "name": "get_weather",
                        "description": "Fetches weather details.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "City name",
                                }
                            },
                            "required": ["location"],
                        },
                    }
                }
            ],
            "functions": [{"name": "get_weather"}, {"name": "get_stock_price"}],
        },
        "litellm_params": {"custom_llm_provider": "openai"},
    }

    # Simulated LLM response object
    response_obj = ModelResponse(
        usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40},
        choices=[
            Choices(message={"role": "assistant", "content": "Basic Response Content"})
        ],
        model="gpt-4o",
        id="chatcmpl-ID",
    )

    # Apply attribute setting via ArizeLogger
    ArizeLogger.set_arize_attributes(span, kwargs, response_obj)

    # Validate that the expected number of attributes were set
    assert span.set_attribute.call_count == 28

    # Metadata attached to the span
    span.set_attribute.assert_any_call(
        SpanAttributes.METADATA, json.dumps({"key_1": "value_1", "key_2": None})
    )

    # Basic LLM information
    span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o")
    span.set_attribute.assert_any_call("llm.request.type", "completion")
    span.set_attribute.assert_any_call(SpanAttributes.LLM_PROVIDER, "openai")

    # LLM generation parameters
    span.set_attribute.assert_any_call("llm.request.max_tokens", "100")
    span.set_attribute.assert_any_call("llm.request.temperature", "1")
    span.set_attribute.assert_any_call("llm.request.top_p", "5")

    # Streaming and user info
    span.set_attribute.assert_any_call("llm.is_streaming", "False")
    span.set_attribute.assert_any_call("llm.user", "test_user")

    # Response metadata
    span.set_attribute.assert_any_call("llm.response.id", "chatcmpl-ID")
    span.set_attribute.assert_any_call("llm.response.model", "gpt-4o")
    span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM")

    # Request message content and metadata
    span.set_attribute.assert_any_call(
        SpanAttributes.INPUT_VALUE, "Basic Request Content"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
        "user",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
        "Basic Request Content",
    )

    # Tool call definitions and function names
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_NAME}", "get_weather"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_DESCRIPTION}",
        "Fetches weather details.",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_PARAMETERS}",
        json.dumps(
            {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"],
            }
        ),
    )

    # Tool calls captured from optional_params
    span.set_attribute.assert_any_call(
        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.0.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
        "get_weather",
    )
    span.set_attribute.assert_any_call(
        f"{MessageAttributes.MESSAGE_TOOL_CALLS}.1.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
        "get_stock_price",
    )

    # Invocation parameters
    span.set_attribute.assert_any_call(
        SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}'
    )

    # User ID
    span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user")

    # Output message content
    span.set_attribute.assert_any_call(
        SpanAttributes.OUTPUT_VALUE, "Basic Response Content"
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
        "assistant",
    )
    span.set_attribute.assert_any_call(
        f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
        "Basic Response Content",
    )

    # Token counts
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40)


class TestArizeLogger(CustomLogger):
    """
    Custom logger implementation to capture standard_callback_dynamic_params.
    Used to verify that dynamic config keys are being passed to callbacks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.standard_callback_dynamic_params: Optional[
            StandardCallbackDynamicParams
        ] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Capture dynamic params and print them for verification
        print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
        self.standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
        )


@pytest.mark.asyncio
async def test_arize_dynamic_params():
    """
    Test to ensure that dynamic Arize keys (API key and space key)
    are received inside the callback logger at runtime.
    """
    test_arize_logger = TestArizeLogger()
    litellm.callbacks = [test_arize_logger]

    # Perform a mocked async completion call to trigger logging
    await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Basic Request Content"}],
        mock_response="test",
        arize_api_key="test_api_key_dynamic",
        arize_space_key="test_space_key_dynamic",
    )

    # Allow for async propagation
    await asyncio.sleep(2)

    # Assert dynamic parameters were received in the callback
    assert test_arize_logger.standard_callback_dynamic_params is not None
    assert (
        test_arize_logger.standard_callback_dynamic_params.get("arize_api_key")
        == "test_api_key_dynamic"
    )
    assert (
        test_arize_logger.standard_callback_dynamic_params.get("arize_space_key")
        == "test_space_key_dynamic"
    )
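With the test now living under tests/litellm/integrations/arize/, it should be runnable with a standard pytest invocation (assuming the repo's dev dependencies, such as pytest-asyncio, are installed):

pytest tests/litellm/integrations/arize/test_arize_utils.py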
(old test file, removed)

@@ -1,111 +0,0 @@
import os
import sys
import time
from unittest.mock import Mock, patch
import json
import opentelemetry.exporter.otlp.proto.grpc.trace_exporter
from typing import Optional

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path
from litellm.integrations._types.open_inference import SpanAttributes
from litellm.integrations.arize.arize import ArizeConfig, ArizeLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.main import completion
import litellm
from litellm.types.utils import Choices, StandardCallbackDynamicParams
import pytest
import asyncio


def test_arize_set_attributes():
    """
    Test setting attributes for Arize
    """
    from unittest.mock import MagicMock
    from litellm.types.utils import ModelResponse

    span = MagicMock()
    kwargs = {
        "role": "user",
        "content": "simple arize test",
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "basic arize test"}],
        "standard_logging_object": {
            "model_parameters": {"user": "test_user"},
            "metadata": {"key": "value", "key2": None},
        },
    }
    response_obj = ModelResponse(
        usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40},
        choices=[Choices(message={"role": "assistant", "content": "response content"})],
    )

    ArizeLogger.set_arize_attributes(span, kwargs, response_obj)

    assert span.set_attribute.call_count == 14
    span.set_attribute.assert_any_call(
        SpanAttributes.METADATA, json.dumps({"key": "value", "key2": None})
    )
    span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o")
    span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM")
    span.set_attribute.assert_any_call(SpanAttributes.INPUT_VALUE, "basic arize test")
    span.set_attribute.assert_any_call("llm.input_messages.0.message.role", "user")
    span.set_attribute.assert_any_call(
        "llm.input_messages.0.message.content", "basic arize test"
    )
    span.set_attribute.assert_any_call(
        SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}'
    )
    span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user")
    span.set_attribute.assert_any_call(SpanAttributes.OUTPUT_VALUE, "response content")
    span.set_attribute.assert_any_call(
        "llm.output_messages.0.message.role", "assistant"
    )
    span.set_attribute.assert_any_call(
        "llm.output_messages.0.message.content", "response content"
    )
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60)
    span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40)


class TestArizeLogger(CustomLogger):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.standard_callback_dynamic_params: Optional[
            StandardCallbackDynamicParams
        ] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("logged kwargs", json.dumps(kwargs, indent=4, default=str))
        self.standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
        )


@pytest.mark.asyncio
async def test_arize_dynamic_params():
    """verify arize ai dynamic params are received by a callback"""
    test_arize_logger = TestArizeLogger()
    litellm.callbacks = [test_arize_logger]
    await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "basic arize test"}],
        mock_response="test",
        arize_api_key="test_api_key_dynamic",
        arize_space_key="test_space_key_dynamic",
    )

    await asyncio.sleep(2)

    assert test_arize_logger.standard_callback_dynamic_params is not None
    assert (
        test_arize_logger.standard_callback_dynamic_params.get("arize_api_key")
        == "test_api_key_dynamic"
    )
    assert (
        test_arize_logger.standard_callback_dynamic_params.get("arize_space_key")
        == "test_space_key_dynamic"
    )