From 6909d8e11bc66eb66d448eab612f3ec9d7579e19 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 16 Oct 2024 08:33:40 +0530
Subject: [PATCH] fix arize handle optional params (#6243)

---
 litellm/integrations/arize_ai.py     | 171 ++++++++++++++++-----
 litellm/proxy/proxy_config.yaml      |   2 +-
 tests/local_testing/test_arize_ai.py |   7 +-
 3 files changed, 106 insertions(+), 74 deletions(-)

diff --git a/litellm/integrations/arize_ai.py b/litellm/integrations/arize_ai.py
index 137c33c48..5a66cfd0c 100644
--- a/litellm/integrations/arize_ai.py
+++ b/litellm/integrations/arize_ai.py
@@ -4,8 +4,11 @@ arize AI is OTEL compatible
 this file has Arize ai specific helper functions
 """

+import json
 from typing import TYPE_CHECKING, Any, Optional, Union

+from litellm._logging import verbose_proxy_logger
+
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

@@ -14,6 +17,21 @@ else:
     Span = Any


+def make_json_serializable(payload: dict) -> dict:
+    for key, value in payload.items():
+        try:
+            if isinstance(value, dict):
+                # recursively sanitize dicts
+                payload[key] = make_json_serializable(value.copy())
+            elif not isinstance(value, (str, int, float, bool, type(None))):
+                # everything else becomes a string
+                payload[key] = str(value)
+        except Exception:
+            # non blocking if it can't cast to a str
+            pass
+    return payload
+
+
 def set_arize_ai_attributes(span: Span, kwargs, response_obj):
     from litellm.integrations._types.open_inference import (
         MessageAttributes,
@@ -22,93 +40,102 @@ def set_arize_ai_attributes(span: Span, kwargs, response_obj):
         SpanAttributes,
     )

-    optional_params = kwargs.get("optional_params", {})
-    # litellm_params = kwargs.get("litellm_params", {}) or {}
+    try:

-    #############################################
-    ############ LLM CALL METADATA ##############
-    #############################################
-    # commented out for now - looks like Arize AI could not log this
-    # metadata = litellm_params.get("metadata", {}) or {}
-    # span.set_attribute(SpanAttributes.METADATA, str(metadata))
+        optional_params = kwargs.get("optional_params", {})
+        # litellm_params = kwargs.get("litellm_params", {}) or {}

-    #############################################
-    ########## LLM Request Attributes ###########
-    #############################################
+        #############################################
+        ############ LLM CALL METADATA ##############
+        #############################################
+        # commented out for now - looks like Arize AI could not log this
+        # metadata = litellm_params.get("metadata", {}) or {}
+        # span.set_attribute(SpanAttributes.METADATA, str(metadata))

-    # The name of the LLM a request is being made to
-    if kwargs.get("model"):
-        span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model"))
+        #############################################
+        ########## LLM Request Attributes ###########
+        #############################################

-    span.set_attribute(
-        SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.LLM.value
-    )
-    messages = kwargs.get("messages")
+        # The name of the LLM a request is being made to
+        if kwargs.get("model"):
+            span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model"))

-    # for /chat/completions
-    # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
-    if messages:
         span.set_attribute(
-            SpanAttributes.INPUT_VALUE,
-            messages[-1].get("content", ""),  # get the last message for input
+            SpanAttributes.OPENINFERENCE_SPAN_KIND,
+            OpenInferenceSpanKindValues.LLM.value,
         )
+        messages = kwargs.get("messages")

-        # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page
-        for idx, msg in enumerate(messages):
-            # Set the role per message
+        # for /chat/completions
+        # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
+        if messages:
             span.set_attribute(
-                f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}",
-                msg["role"],
-            )
-            # Set the content per message
-            span.set_attribute(
-                f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}",
-                msg.get("content", ""),
+                SpanAttributes.INPUT_VALUE,
+                messages[-1].get("content", ""),  # get the last message for input
             )

-    # The Generative AI Provider: Azure, OpenAI, etc.
-    span.set_attribute(SpanAttributes.LLM_INVOCATION_PARAMETERS, str(optional_params))
+            # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page
+            for idx, msg in enumerate(messages):
+                # Set the role per message
+                span.set_attribute(
+                    f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}",
+                    msg["role"],
+                )
+                # Set the content per message
+                span.set_attribute(
+                    f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}",
+                    msg.get("content", ""),
+                )

-    if optional_params.get("user"):
-        span.set_attribute(SpanAttributes.USER_ID, optional_params.get("user"))
-
-    #############################################
-    ########## LLM Response Attributes ##########
-    # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
-    #############################################
-    for choice in response_obj.get("choices"):
-        response_message = choice.get("message", {})
+        # The Generative AI Provider: Azure, OpenAI, etc.
+        _optional_params = make_json_serializable(optional_params)
+        _json_optional_params = json.dumps(_optional_params)
         span.set_attribute(
-            SpanAttributes.OUTPUT_VALUE, response_message.get("content", "")
+            SpanAttributes.LLM_INVOCATION_PARAMETERS, _json_optional_params
        )

-        # This shows up under `output_messages` tab on the span page
-        # This code assumes a single response
-        span.set_attribute(
-            f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
-            response_message["role"],
-        )
-        span.set_attribute(
-            f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
-            response_message.get("content", ""),
-        )
+        if optional_params.get("user"):
+            span.set_attribute(SpanAttributes.USER_ID, optional_params.get("user"))

-    usage = response_obj.get("usage")
-    if usage:
-        span.set_attribute(
-            SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
-            usage.get("total_tokens"),
-        )
+        #############################################
+        ########## LLM Response Attributes ##########
+        # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
+        #############################################
+        for choice in response_obj.get("choices"):
+            response_message = choice.get("message", {})
+            span.set_attribute(
+                SpanAttributes.OUTPUT_VALUE, response_message.get("content", "")
+            )

-        # The number of tokens used in the LLM response (completion).
-        span.set_attribute(
-            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
-            usage.get("completion_tokens"),
-        )
+            # This shows up under `output_messages` tab on the span page
+            # This code assumes a single response
+            span.set_attribute(
+                f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
+                response_message["role"],
+            )
+            span.set_attribute(
+                f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
+                response_message.get("content", ""),
+            )

-        # The number of tokens used in the LLM prompt.
-        span.set_attribute(
-            SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
-            usage.get("prompt_tokens"),
-        )
-    pass
+        usage = response_obj.get("usage")
+        if usage:
+            span.set_attribute(
+                SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
+                usage.get("total_tokens"),
+            )
+
+            # The number of tokens used in the LLM response (completion).
+            span.set_attribute(
+                SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
+                usage.get("completion_tokens"),
+            )
+
+            # The number of tokens used in the LLM prompt.
+            span.set_attribute(
+                SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
+                usage.get("prompt_tokens"),
+            )
+        pass
+    except Exception as e:
+        verbose_proxy_logger.error(f"Error setting arize attributes: {e}")
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 50d1a91bf..7c70332fd 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -6,5 +6,5 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/

 litellm_settings:
-  callbacks: ["otel"]
+  callbacks: ["arize"]

diff --git a/tests/local_testing/test_arize_ai.py b/tests/local_testing/test_arize_ai.py
index dfc00446e..f6ccc75f2 100644
--- a/tests/local_testing/test_arize_ai.py
+++ b/tests/local_testing/test_arize_ai.py
@@ -8,7 +8,7 @@ from dotenv import load_dotenv
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

 import litellm
-from litellm._logging import verbose_logger
+from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig

 load_dotenv()
@@ -18,6 +18,9 @@ import logging
 @pytest.mark.asyncio()
 async def test_async_otel_callback():
     litellm.set_verbose = True
+
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+    verbose_logger.setLevel(logging.DEBUG)
     litellm.success_callback = ["arize"]

     await litellm.acompletion(
@@ -27,3 +30,5 @@ async def test_async_otel_callback():
         temperature=0.1,
         user="OTEL_USER",
     )
+
+    await asyncio.sleep(2)
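A minimal sketch of the behavior this patch introduces, assuming the patched litellm is installed (the optional_params values below are illustrative, not taken from the PR): make_json_serializable stringifies anything json.dumps cannot handle, so LLM_INVOCATION_PARAMETERS is emitted as valid JSON instead of the old str(optional_params).

    import json
    from litellm.integrations.arize_ai import make_json_serializable

    # Illustrative optional_params: the nested object() stands in for any value
    # (client handle, pydantic model, callable, ...) that json.dumps would reject.
    optional_params = {
        "temperature": 0.1,
        "user": "OTEL_USER",
        "extra_headers": {"trace_ctx": object()},
    }

    sanitized = make_json_serializable(optional_params)
    print(json.dumps(sanitized))  # no longer raises; non-serializable values are now strings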