diff --git a/litellm/integrations/arize/_utils.py b/litellm/integrations/arize/_utils.py index aded5f7469..4697245f57 100644 --- a/litellm/integrations/arize/_utils.py +++ b/litellm/integrations/arize/_utils.py @@ -19,13 +19,16 @@ def set_attributes(span: Span, kwargs, response_obj): ) try: - litellm_params = kwargs.get("litellm_params", {}) or {} - + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + ############################################# ############ LLM CALL METADATA ############## ############################################# - metadata = litellm_params.get("metadata", {}) or {} - span.set_attribute(SpanAttributes.METADATA, str(metadata)) + + if standard_logging_payload and (metadata := standard_logging_payload["metadata"]): + span.set_attribute(SpanAttributes.METADATA, json.dumps(metadata)) ############################################# ########## LLM Request Attributes ########### @@ -62,9 +65,6 @@ def set_attributes(span: Span, kwargs, response_obj): msg.get("content", ""), ) - standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( - "standard_logging_object" - ) if standard_logging_payload and (model_params := standard_logging_payload["model_parameters"]): # The Generative AI Provider: Azure, OpenAI, etc. span.set_attribute( diff --git a/tests/local_testing/test_arize_ai.py b/tests/local_testing/test_arize_ai.py index 0caf1b5e46..309cc689ea 100644 --- a/tests/local_testing/test_arize_ai.py +++ b/tests/local_testing/test_arize_ai.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from litellm import Choices @@ -75,8 +76,7 @@ def test_arize_set_attributes(): "content": "simple arize test", "model": "gpt-4o", "messages": [{"role": "user", "content": "basic arize test"}], - "litellm_params": {"metadata": {"key": "value"}}, - "standard_logging_object": {"model_parameters": {"user": "test_user"}} + "standard_logging_object": {"model_parameters": {"user": "test_user"}, "metadata": {"key": "value", "key2": None}}, } response_obj = ModelResponse(usage={"total_tokens": 100, "completion_tokens": 60, "prompt_tokens": 40}, choices=[Choices(message={"role": "assistant", "content": "response content"})]) @@ -84,7 +84,7 @@ def test_arize_set_attributes(): ArizeLogger.set_arize_attributes(span, kwargs, response_obj) assert span.set_attribute.call_count == 14 - span.set_attribute.assert_any_call(SpanAttributes.METADATA, str({"key": "value"})) + span.set_attribute.assert_any_call(SpanAttributes.METADATA, json.dumps({"key": "value", "key2": None})) span.set_attribute.assert_any_call(SpanAttributes.LLM_MODEL_NAME, "gpt-4o") span.set_attribute.assert_any_call(SpanAttributes.OPENINFERENCE_SPAN_KIND, "LLM") span.set_attribute.assert_any_call(SpanAttributes.INPUT_VALUE, "basic arize test") diff --git a/tests/logging_callback_tests/test_arize_logging.py b/tests/logging_callback_tests/test_arize_logging.py index 59257ad905..27f067d728 100644 --- a/tests/logging_callback_tests/test_arize_logging.py +++ b/tests/logging_callback_tests/test_arize_logging.py @@ -19,6 +19,8 @@ def test_arize_callback(): os.environ["ARIZE_API_KEY"] = "test_api_key" os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1" + # Set the batch span processor to quickly flush after a span has been added + # This is to ensure that the span is exported before the test ends os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1" os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1" os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1" @@ -36,5 +38,5 @@ def test_arize_callback(): mock_response="hello there!", ) - time.sleep(1) + time.sleep(1) # Wait for the batch span processor to flush assert patched_export.called