From 96e31d205c710acb20f03ee950290f833fd3c1f6 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Tue, 22 Apr 2025 21:34:51 -0700
Subject: [PATCH] feat: Added Missing Attributes For Arize & Phoenix Integration (#10043) (#10215)

* feat: Added Missing Attributes For Arize & Phoenix Integration

* chore: Added noqa for PLR0915 to suppress warning

* chore: Moved Contributor Test to Correct Location

* chore: Removed Redundant Fallback

Co-authored-by: Ali Saleh
---
 litellm/integrations/_types/open_inference.py | 103 +++++++
 litellm/integrations/arize/_utils.py          | 251 ++++++++++++++----
 .../integrations/arize/test_arize_utils.py    | 231 ++++++++++++++++
 .../test_arize_logging.py                     | 111 --------
 4 files changed, 540 insertions(+), 156 deletions(-)
 create mode 100644 tests/litellm/integrations/arize/test_arize_utils.py
 delete mode 100644 tests/logging_callback_tests/test_arize_logging.py

diff --git a/litellm/integrations/_types/open_inference.py b/litellm/integrations/_types/open_inference.py
index bcfabe9b7b..65ecadcf37 100644
--- a/litellm/integrations/_types/open_inference.py
+++ b/litellm/integrations/_types/open_inference.py
@@ -45,6 +45,14 @@ class SpanAttributes:
     """
     The name of the model being used.
     """
+    LLM_PROVIDER = "llm.provider"
+    """
+    The provider of the model, such as OpenAI, Azure, Google, etc.
+    """
+    LLM_SYSTEM = "llm.system"
+    """
+    The AI product as identified by the client or server.
+    """
     LLM_PROMPTS = "llm.prompts"
     """
     Prompts provided to a completions API.
@@ -65,15 +73,40 @@ class SpanAttributes:
     """
     Number of tokens in the prompt.
     """
+    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = "llm.token_count.prompt_details.cache_write"
+    """
+    Number of tokens in the prompt that were written to cache.
+    """
+    LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = "llm.token_count.prompt_details.cache_read"
+    """
+    Number of tokens in the prompt that were read from cache.
+    """
+    LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = "llm.token_count.prompt_details.audio"
+    """
+    The number of audio input tokens present in the prompt.
+    """
     LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
     """
     Number of tokens in the completion.
     """
+    LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = "llm.token_count.completion_details.reasoning"
+    """
+    Number of tokens used for reasoning steps in the completion.
+    """
+    LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = "llm.token_count.completion_details.audio"
+    """
+    The number of audio output tokens generated by the model.
+    """
     LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"
     """
     Total number of tokens, including both prompt and completion.
     """
+    LLM_TOOLS = "llm.tools"
+    """
+    List of tools that are advertised to the LLM and that it is able to call.
+    """
+
     TOOL_NAME = "tool.name"
     """
     Name of the tool being used.
@@ -112,6 +145,19 @@ class SpanAttributes:
     The id of the user
     """

+    PROMPT_VENDOR = "prompt.vendor"
+    """
+    The vendor or origin of the prompt, e.g. a prompt library or a specialized service.
+    """
+    PROMPT_ID = "prompt.id"
+    """
+    A vendor-specific id used to locate the prompt.
+    """
+    PROMPT_URL = "prompt.url"
+    """
+    A vendor-specific url used to locate the prompt.
+    """
+

 class MessageAttributes:
     """
@@ -151,6 +197,10 @@ class MessageAttributes:
     The JSON string representing the arguments passed to the function
     during a function call.
     """
+    MESSAGE_TOOL_CALL_ID = "message.tool_call_id"
+    """
+    The id of the tool call.
+ """ class MessageContentAttributes: @@ -186,6 +236,25 @@ class ImageAttributes: """ +class AudioAttributes: + """ + Attributes for audio + """ + + AUDIO_URL = "audio.url" + """ + The url to an audio file + """ + AUDIO_MIME_TYPE = "audio.mime_type" + """ + The mime type of the audio file + """ + AUDIO_TRANSCRIPT = "audio.transcript" + """ + The transcript of the audio file + """ + + class DocumentAttributes: """ Attributes for a document. @@ -257,6 +326,10 @@ class ToolCallAttributes: Attributes for a tool call """ + TOOL_CALL_ID = "tool_call.id" + """ + The id of the tool call. + """ TOOL_CALL_FUNCTION_NAME = "tool_call.function.name" """ The name of function that is being called during a tool call. @@ -268,6 +341,18 @@ class ToolCallAttributes: """ +class ToolAttributes: + """ + Attributes for a tools + """ + + TOOL_JSON_SCHEMA = "tool.json_schema" + """ + The json schema of a tool input, It is RECOMMENDED that this be in the + OpenAI tool calling format: https://platform.openai.com/docs/assistants/tools + """ + + class OpenInferenceSpanKindValues(Enum): TOOL = "TOOL" CHAIN = "CHAIN" @@ -284,3 +369,21 @@ class OpenInferenceSpanKindValues(Enum): class OpenInferenceMimeTypeValues(Enum): TEXT = "text/plain" JSON = "application/json" + + +class OpenInferenceLLMSystemValues(Enum): + OPENAI = "openai" + ANTHROPIC = "anthropic" + COHERE = "cohere" + MISTRALAI = "mistralai" + VERTEXAI = "vertexai" + + +class OpenInferenceLLMProviderValues(Enum): + OPENAI = "openai" + ANTHROPIC = "anthropic" + COHERE = "cohere" + MISTRALAI = "mistralai" + GOOGLE = "google" + AZURE = "azure" + AWS = "aws" diff --git a/litellm/integrations/arize/_utils.py b/litellm/integrations/arize/_utils.py index 5a090968b4..e93ef128b4 100644 --- a/litellm/integrations/arize/_utils.py +++ b/litellm/integrations/arize/_utils.py @@ -1,3 +1,4 @@ +import json from typing import TYPE_CHECKING, Any, Optional, Union from litellm._logging import verbose_logger @@ -12,36 +13,141 @@ else: Span = Any -def set_attributes(span: Span, kwargs, response_obj): +def cast_as_primitive_value_type(value) -> Union[str, bool, int, float]: + """ + Converts a value to an OTEL-supported primitive for Arize/Phoenix observability. + """ + if value is None: + return "" + if isinstance(value, (str, bool, int, float)): + return value + try: + return str(value) + except Exception: + return "" + + +def safe_set_attribute(span: Span, key: str, value: Any): + """ + Sets a span attribute safely with OTEL-compliant primitive typing for Arize/Phoenix. + """ + primitive_value = cast_as_primitive_value_type(value) + span.set_attribute(key, primitive_value) + + +def set_attributes(span: Span, kwargs, response_obj): # noqa: PLR0915 + """ + Populates span with OpenInference-compliant LLM attributes for Arize and Phoenix tracing. 
+ """ from litellm.integrations._types.open_inference import ( MessageAttributes, OpenInferenceSpanKindValues, SpanAttributes, + ToolCallAttributes, ) try: + optional_params = kwargs.get("optional_params", {}) + litellm_params = kwargs.get("litellm_params", {}) standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object" ) + if standard_logging_payload is None: + raise ValueError("standard_logging_object not found in kwargs") ############################################# ############ LLM CALL METADATA ############## ############################################# - if standard_logging_payload and ( - metadata := standard_logging_payload["metadata"] - ): - span.set_attribute(SpanAttributes.METADATA, safe_dumps(metadata)) + # Set custom metadata for observability and trace enrichment. + metadata = ( + standard_logging_payload.get("metadata") + if standard_logging_payload + else None + ) + if metadata is not None: + safe_set_attribute(span, SpanAttributes.METADATA, safe_dumps(metadata)) ############################################# ########## LLM Request Attributes ########### ############################################# - # The name of the LLM a request is being made to + # The name of the LLM a request is being made to. if kwargs.get("model"): - span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model")) + safe_set_attribute( + span, + SpanAttributes.LLM_MODEL_NAME, + kwargs.get("model"), + ) - span.set_attribute( + # The LLM request type. + safe_set_attribute( + span, + "llm.request.type", + standard_logging_payload["call_type"], + ) + + # The Generative AI Provider: Azure, OpenAI, etc. + safe_set_attribute( + span, + SpanAttributes.LLM_PROVIDER, + litellm_params.get("custom_llm_provider", "Unknown"), + ) + + # The maximum number of tokens the LLM generates for a request. + if optional_params.get("max_tokens"): + safe_set_attribute( + span, + "llm.request.max_tokens", + optional_params.get("max_tokens"), + ) + + # The temperature setting for the LLM request. + if optional_params.get("temperature"): + safe_set_attribute( + span, + "llm.request.temperature", + optional_params.get("temperature"), + ) + + # The top_p sampling setting for the LLM request. + if optional_params.get("top_p"): + safe_set_attribute( + span, + "llm.request.top_p", + optional_params.get("top_p"), + ) + + # Indicates whether response is streamed. + safe_set_attribute( + span, + "llm.is_streaming", + str(optional_params.get("stream", False)), + ) + + # Logs the user ID if present. + if optional_params.get("user"): + safe_set_attribute( + span, + "llm.user", + optional_params.get("user"), + ) + + # The unique identifier for the completion. + if response_obj and response_obj.get("id"): + safe_set_attribute(span, "llm.response.id", response_obj.get("id")) + + # The model used to generate the response. + if response_obj and response_obj.get("model"): + safe_set_attribute( + span, + "llm.response.model", + response_obj.get("model"), + ) + + # Required by OpenInference to mark span as LLM kind. 
+ safe_set_attribute( + span, SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.LLM.value, ) @@ -50,77 +156,132 @@ def set_attributes(span: Span, kwargs, response_obj): # for /chat/completions # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions if messages: - span.set_attribute( + last_message = messages[-1] + safe_set_attribute( + span, SpanAttributes.INPUT_VALUE, - messages[-1].get("content", ""), # get the last message for input + last_message.get("content", ""), ) - # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page + # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page. for idx, msg in enumerate(messages): - # Set the role per message - span.set_attribute( - f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}", - msg["role"], + prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}" + # Set the role per message. + safe_set_attribute( + span, f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", msg.get("role") ) - # Set the content per message - span.set_attribute( - f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}", + # Set the content per message. + safe_set_attribute( + span, + f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}", msg.get("content", ""), ) - if standard_logging_payload and ( - model_params := standard_logging_payload["model_parameters"] - ): + # Capture tools (function definitions) used in the LLM call. + tools = optional_params.get("tools") + if tools: + for idx, tool in enumerate(tools): + function = tool.get("function") + if not function: + continue + prefix = f"{SpanAttributes.LLM_TOOLS}.{idx}" + safe_set_attribute( + span, f"{prefix}.{SpanAttributes.TOOL_NAME}", function.get("name") + ) + safe_set_attribute( + span, + f"{prefix}.{SpanAttributes.TOOL_DESCRIPTION}", + function.get("description"), + ) + safe_set_attribute( + span, + f"{prefix}.{SpanAttributes.TOOL_PARAMETERS}", + json.dumps(function.get("parameters")), + ) + + # Capture tool calls made during function-calling LLM flows. + functions = optional_params.get("functions") + if functions: + for idx, function in enumerate(functions): + prefix = f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{idx}" + safe_set_attribute( + span, + f"{prefix}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", + function.get("name"), + ) + + # Capture invocation parameters and user ID if available. + model_params = ( + standard_logging_payload.get("model_parameters") + if standard_logging_payload + else None + ) + if model_params: # The Generative AI Provider: Azure, OpenAI, etc. 
- span.set_attribute( - SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_dumps(model_params) + safe_set_attribute( + span, + SpanAttributes.LLM_INVOCATION_PARAMETERS, + safe_dumps(model_params), ) if model_params.get("user"): user_id = model_params.get("user") if user_id is not None: - span.set_attribute(SpanAttributes.USER_ID, user_id) + safe_set_attribute(span, SpanAttributes.USER_ID, user_id) ############################################# ########## LLM Response Attributes ########## - # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions ############################################# - if hasattr(response_obj, "get"): - for choice in response_obj.get("choices", []): - response_message = choice.get("message", {}) - span.set_attribute( - SpanAttributes.OUTPUT_VALUE, response_message.get("content", "") - ) - # This shows up under `output_messages` tab on the span page - # This code assumes a single response - span.set_attribute( - f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}", - response_message.get("role"), - ) - span.set_attribute( - f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}", + # Captures response tokens, message, and content. + if hasattr(response_obj, "get"): + for idx, choice in enumerate(response_obj.get("choices", [])): + response_message = choice.get("message", {}) + safe_set_attribute( + span, + SpanAttributes.OUTPUT_VALUE, response_message.get("content", ""), ) - usage = response_obj.get("usage") + # This shows up under `output_messages` tab on the span page. + prefix = f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}" + safe_set_attribute( + span, + f"{prefix}.{MessageAttributes.MESSAGE_ROLE}", + response_message.get("role"), + ) + safe_set_attribute( + span, + f"{prefix}.{MessageAttributes.MESSAGE_CONTENT}", + response_message.get("content", ""), + ) + + # Token usage info. + usage = response_obj and response_obj.get("usage") if usage: - span.set_attribute( + safe_set_attribute( + span, SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens"), ) # The number of tokens used in the LLM response (completion). - span.set_attribute( + safe_set_attribute( + span, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage.get("completion_tokens"), ) # The number of tokens used in the LLM prompt. 
- span.set_attribute( + safe_set_attribute( + span, SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage.get("prompt_tokens"), ) - pass + except Exception as e: - verbose_logger.error(f"Error setting arize attributes: {e}") + verbose_logger.error( + f"[Arize/Phoenix] Failed to set OpenInference span attributes: {e}" + ) + if hasattr(span, "record_exception"): + span.record_exception(e) diff --git a/tests/litellm/integrations/arize/test_arize_utils.py b/tests/litellm/integrations/arize/test_arize_utils.py new file mode 100644 index 0000000000..bea42faaa8 --- /dev/null +++ b/tests/litellm/integrations/arize/test_arize_utils.py @@ -0,0 +1,231 @@ +import json +import os +import sys +from typing import Optional + +# Adds the grandparent directory to sys.path to allow importing project modules +sys.path.insert(0, os.path.abspath("../..")) + +import asyncio +import litellm +import pytest +from litellm.integrations.arize.arize import ArizeLogger +from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations._types.open_inference import ( + SpanAttributes, + MessageAttributes, + ToolCallAttributes, +) +from litellm.types.utils import Choices, StandardCallbackDynamicParams + + +def test_arize_set_attributes(): + """ + Test setting attributes for Arize, including all custom LLM attributes. + Ensures that the correct span attributes are being added during a request. + """ + from unittest.mock import MagicMock + from litellm.types.utils import ModelResponse + + span = MagicMock() # Mocked tracing span to test attribute setting + + # Construct kwargs to simulate a real LLM request scenario + kwargs = { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Basic Request Content"}], + "standard_logging_object": { + "model_parameters": {"user": "test_user"}, + "metadata": {"key_1": "value_1", "key_2": None}, + "call_type": "completion", + }, + "optional_params": { + "max_tokens": "100", + "temperature": "1", + "top_p": "5", + "stream": False, + "user": "test_user", + "tools": [ + { + "function": { + "name": "get_weather", + "description": "Fetches weather details.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name", + } + }, + "required": ["location"], + }, + } + } + ], + "functions": [{"name": "get_weather"}, {"name": "get_stock_price"}], + }, + "litellm_params": {"custom_llm_provider": "openai"}, + } + + # Simulated LLM response object + response_obj = ModelResponse( + usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40}, + choices=[ + Choices(message={"role": "assistant", "content": "Basic Response Content"}) + ], + model="gpt-4o", + id="chatcmpl-ID", + ) + + # Apply attribute setting via ArizeLogger + ArizeLogger.set_arize_attributes(span, kwargs, response_obj) + + # Validate that the expected number of attributes were set + assert span.set_attribute.call_count == 28 + + # Metadata attached to the span + span.set_attribute.assert_any_call( + SpanAttributes.METADATA, json.dumps({"key_1": "value_1", "key_2": None}) + ) + + # Basic LLM information + span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o") + span.set_attribute.assert_any_call("llm.request.type", "completion") + span.set_attribute.assert_any_call(SpanAttributes.LLM_PROVIDER, "openai") + + # LLM generation parameters + span.set_attribute.assert_any_call("llm.request.max_tokens", "100") + span.set_attribute.assert_any_call("llm.request.temperature", "1") + 
span.set_attribute.assert_any_call("llm.request.top_p", "5") + + # Streaming and user info + span.set_attribute.assert_any_call("llm.is_streaming", "False") + span.set_attribute.assert_any_call("llm.user", "test_user") + + # Response metadata + span.set_attribute.assert_any_call("llm.response.id", "chatcmpl-ID") + span.set_attribute.assert_any_call("llm.response.model", "gpt-4o") + span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM") + + # Request message content and metadata + span.set_attribute.assert_any_call( + SpanAttributes.INPUT_VALUE, "Basic Request Content" + ) + span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}", + "user", + ) + span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}", + "Basic Request Content", + ) + + # Tool call definitions and function names + span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_NAME}", "get_weather" + ) + span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_DESCRIPTION}", + "Fetches weather details.", + ) + span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_TOOLS}.0.{SpanAttributes.TOOL_PARAMETERS}", + json.dumps( + { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"} + }, + "required": ["location"], + } + ), + ) + + # Tool calls captured from optional_params + span.set_attribute.assert_any_call( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.0.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", + "get_weather", + ) + span.set_attribute.assert_any_call( + f"{MessageAttributes.MESSAGE_TOOL_CALLS}.1.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", + "get_stock_price", + ) + + # Invocation parameters + span.set_attribute.assert_any_call( + SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}' + ) + + # User ID + span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user") + + # Output message content + span.set_attribute.assert_any_call( + SpanAttributes.OUTPUT_VALUE, "Basic Response Content" + ) + span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}", + "assistant", + ) + span.set_attribute.assert_any_call( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}", + "Basic Response Content", + ) + + # Token counts + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40) + + +class TestArizeLogger(CustomLogger): + """ + Custom logger implementation to capture standard_callback_dynamic_params. + Used to verify that dynamic config keys are being passed to callbacks. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + # Capture dynamic params and print them for verification + print("logged kwargs", json.dumps(kwargs, indent=4, default=str)) + self.standard_callback_dynamic_params = kwargs.get( + "standard_callback_dynamic_params" + ) + + +@pytest.mark.asyncio +async def test_arize_dynamic_params(): + """ + Test to ensure that dynamic Arize keys (API key and space key) + are received inside the callback logger at runtime. + """ + test_arize_logger = TestArizeLogger() + litellm.callbacks = [test_arize_logger] + + # Perform a mocked async completion call to trigger logging + await litellm.acompletion( + model="gpt-4o", + messages=[{"role": "user", "content": "Basic Request Content"}], + mock_response="test", + arize_api_key="test_api_key_dynamic", + arize_space_key="test_space_key_dynamic", + ) + + # Allow for async propagation + await asyncio.sleep(2) + + # Assert dynamic parameters were received in the callback + assert test_arize_logger.standard_callback_dynamic_params is not None + assert ( + test_arize_logger.standard_callback_dynamic_params.get("arize_api_key") + == "test_api_key_dynamic" + ) + assert ( + test_arize_logger.standard_callback_dynamic_params.get("arize_space_key") + == "test_space_key_dynamic" + ) diff --git a/tests/logging_callback_tests/test_arize_logging.py b/tests/logging_callback_tests/test_arize_logging.py deleted file mode 100644 index aca3ae9a02..0000000000 --- a/tests/logging_callback_tests/test_arize_logging.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import sys -import time -from unittest.mock import Mock, patch -import json -import opentelemetry.exporter.otlp.proto.grpc.trace_exporter -from typing import Optional - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system-path -from litellm.integrations._types.open_inference import SpanAttributes -from litellm.integrations.arize.arize import ArizeConfig, ArizeLogger -from litellm.integrations.custom_logger import CustomLogger -from litellm.main import completion -import litellm -from litellm.types.utils import Choices, StandardCallbackDynamicParams -import pytest -import asyncio - - -def test_arize_set_attributes(): - """ - Test setting attributes for Arize - """ - from unittest.mock import MagicMock - from litellm.types.utils import ModelResponse - - span = MagicMock() - kwargs = { - "role": "user", - "content": "simple arize test", - "model": "gpt-4o", - "messages": [{"role": "user", "content": "basic arize test"}], - "standard_logging_object": { - "model_parameters": {"user": "test_user"}, - "metadata": {"key": "value", "key2": None}, - }, - } - response_obj = ModelResponse( - usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40}, - choices=[Choices(message={"role": "assistant", "content": "response content"})], - ) - - ArizeLogger.set_arize_attributes(span, kwargs, response_obj) - - assert span.set_attribute.call_count == 14 - span.set_attribute.assert_any_call( - SpanAttributes.METADATA, json.dumps({"key": "value", "key2": None}) - ) - span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o") - span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM") - span.set_attribute.assert_any_call(SpanAttributes.INPUT_VALUE, "basic arize test") - 
span.set_attribute.assert_any_call("llm.input_messages.0.message.role", "user") - span.set_attribute.assert_any_call( - "llm.input_messages.0.message.content", "basic arize test" - ) - span.set_attribute.assert_any_call( - SpanAttributes.LLM_INVOCATION_PARAMETERS, '{"user": "test_user"}' - ) - span.set_attribute.assert_any_call(SpanAttributes.USER_ID, "test_user") - span.set_attribute.assert_any_call(SpanAttributes.OUTPUT_VALUE, "response content") - span.set_attribute.assert_any_call( - "llm.output_messages.0.message.role", "assistant" - ) - span.set_attribute.assert_any_call( - "llm.output_messages.0.message.content", "response content" - ) - span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100) - span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60) - span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40) - - -class TestArizeLogger(CustomLogger): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.standard_callback_dynamic_params: Optional[ - StandardCallbackDynamicParams - ] = None - - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - print("logged kwargs", json.dumps(kwargs, indent=4, default=str)) - self.standard_callback_dynamic_params = kwargs.get( - "standard_callback_dynamic_params" - ) - - -@pytest.mark.asyncio -async def test_arize_dynamic_params(): - """verify arize ai dynamic params are recieved by a callback""" - test_arize_logger = TestArizeLogger() - litellm.callbacks = [test_arize_logger] - await litellm.acompletion( - model="gpt-4o", - messages=[{"role": "user", "content": "basic arize test"}], - mock_response="test", - arize_api_key="test_api_key_dynamic", - arize_space_key="test_space_key_dynamic", - ) - - await asyncio.sleep(2) - - assert test_arize_logger.standard_callback_dynamic_params is not None - assert ( - test_arize_logger.standard_callback_dynamic_params.get("arize_api_key") - == "test_api_key_dynamic" - ) - assert ( - test_arize_logger.standard_callback_dynamic_params.get("arize_space_key") - == "test_space_key_dynamic" - )