fix arize handle optional params (#6243)

Ishaan Jaff 2024-10-16 08:33:40 +05:30 committed by GitHub
parent 1994100028
commit 6909d8e11b
3 changed files with 106 additions and 74 deletions


@@ -4,8 +4,11 @@ arize AI is OTEL compatible
this file has Arize ai specific helper functions
"""

import json
from typing import TYPE_CHECKING, Any, Optional, Union

from litellm._logging import verbose_proxy_logger

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

@@ -14,6 +17,21 @@ else:
    Span = Any


def make_json_serializable(payload: dict) -> dict:
    for key, value in payload.items():
        try:
            if isinstance(value, dict):
                # recursively sanitize dicts
                payload[key] = make_json_serializable(value.copy())
            elif not isinstance(value, (str, int, float, bool, type(None))):
                # everything else becomes a string
                payload[key] = str(value)
        except Exception:
            # non blocking if it can't cast to a str
            pass
    return payload


def set_arize_ai_attributes(span: Span, kwargs, response_obj):
    from litellm.integrations._types.open_inference import (
        MessageAttributes,
@@ -22,93 +40,102 @@ def set_arize_ai_attributes(span: Span, kwargs, response_obj):
        SpanAttributes,
    )

    try:
        optional_params = kwargs.get("optional_params", {})
        # litellm_params = kwargs.get("litellm_params", {}) or {}

        #############################################
        ############ LLM CALL METADATA ##############
        #############################################
        # commented out for now - looks like Arize AI could not log this
        # metadata = litellm_params.get("metadata", {}) or {}
        # span.set_attribute(SpanAttributes.METADATA, str(metadata))

        #############################################
        ########## LLM Request Attributes ###########
        #############################################

        # The name of the LLM a request is being made to
        if kwargs.get("model"):
            span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model"))

        span.set_attribute(
            SpanAttributes.OPENINFERENCE_SPAN_KIND,
            OpenInferenceSpanKindValues.LLM.value,
        )
        messages = kwargs.get("messages")

        # for /chat/completions
        # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
        if messages:
            span.set_attribute(
                SpanAttributes.INPUT_VALUE,
                messages[-1].get("content", ""),  # get the last message for input
            )

            # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page
            for idx, msg in enumerate(messages):
                # Set the role per message
                span.set_attribute(
                    f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}",
                    msg["role"],
                )
                # Set the content per message
                span.set_attribute(
                    f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}",
                    msg.get("content", ""),
                )

        # The Generative AI Provider: Azure, OpenAI, etc.
        _optional_params = make_json_serializable(optional_params)
        _json_optional_params = json.dumps(_optional_params)
        span.set_attribute(
            SpanAttributes.LLM_INVOCATION_PARAMETERS, _json_optional_params
        )

        if optional_params.get("user"):
            span.set_attribute(SpanAttributes.USER_ID, optional_params.get("user"))

        #############################################
        ########## LLM Response Attributes ##########
        # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
        #############################################
        for choice in response_obj.get("choices"):
            response_message = choice.get("message", {})
            span.set_attribute(
                SpanAttributes.OUTPUT_VALUE, response_message.get("content", "")
            )

            # This shows up under `output_messages` tab on the span page
            # This code assumes a single response
            span.set_attribute(
                f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
                response_message["role"],
            )
            span.set_attribute(
                f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
                response_message.get("content", ""),
            )

        usage = response_obj.get("usage")
        if usage:
            span.set_attribute(
                SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
                usage.get("total_tokens"),
            )

            # The number of tokens used in the LLM response (completion).
            span.set_attribute(
                SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
                usage.get("completion_tokens"),
            )

            # The number of tokens used in the LLM prompt.
            span.set_attribute(
                SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
                usage.get("prompt_tokens"),
            )
        pass
    except Exception as e:
        verbose_proxy_logger.error(f"Error setting arize attributes: {e}")
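
For context on why the invocation parameters now go through JSON serialization instead of the earlier str(optional_params) call: optional_params can carry values that json.dumps rejects (sets, provider SDK objects, and the like), so the new helper stringifies anything that is not a plain JSON type before the dump. Below is a minimal sketch of that flow; the import path and the sample parameter values are assumptions for illustration, not taken from this commit.

import json

# assumed import path, based on where the helper is added in the diff above
from litellm.integrations.arize_ai import make_json_serializable

# hypothetical optional params; the set and the bare object are not JSON serializable
optional_params = {
    "temperature": 0.1,
    "user": "OTEL_USER",
    "stop": {"###"},
    "response_format": object(),
}

try:
    json.dumps(optional_params)
except TypeError as err:
    print(f"raw optional_params cannot be serialized: {err}")

# sanitize first, then serialize for SpanAttributes.LLM_INVOCATION_PARAMETERS
print(json.dumps(make_json_serializable(optional_params)))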


@@ -6,5 +6,5 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
 
 litellm_settings:
-  callbacks: ["otel"]
+  callbacks: ["arize"]
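
For reference, the same arize callback can also be enabled directly from Python rather than through the proxy config; a rough sketch follows, where the credential env var names, placeholder values, and model are assumptions and not something this commit configures.

import os
import litellm

# assumed credential env vars for the Arize exporter; values are placeholders
os.environ["ARIZE_SPACE_KEY"] = "your-space-key"
os.environ["ARIZE_API_KEY"] = "your-api-key"
# plus your provider key (e.g. OPENAI_API_KEY) for the completion call itself

# same effect as `callbacks: ["arize"]` in the proxy config above
litellm.success_callback = ["arize"]

response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder; any model configured for your keys
    messages=[{"role": "user", "content": "hello"}],
)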


@@ -8,7 +8,7 @@ from dotenv import load_dotenv
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 
 import litellm
-from litellm._logging import verbose_logger
+from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
 
 load_dotenv()
@@ -18,6 +18,9 @@ import logging
 @pytest.mark.asyncio()
 async def test_async_otel_callback():
     litellm.set_verbose = True
+    verbose_proxy_logger.setLevel(logging.DEBUG)
+    verbose_logger.setLevel(logging.DEBUG)
+
     litellm.success_callback = ["arize"]
 
     await litellm.acompletion(
@@ -27,3 +30,5 @@ async def test_async_otel_callback():
         temperature=0.1,
         user="OTEL_USER",
     )
+
+    await asyncio.sleep(2)