forked from phoenix/litellm-mirror
(feat) Arize - Allow using Arize HTTP endpoint (#6364)

* arize use helper for get_arize_opentelemetry_config
* use helper to get Arize OTEL config
* arize add helpers for arize
* docs allow using arize http endpoint
* fix importing OTEL for Arize
* use static methods for ArizeLogger
* fix ArizeLogger tests

parent f943410e32
commit b75019c1a5

7 changed files with 257 additions and 124 deletions
@@ -62,7 +62,8 @@ litellm_settings:
 environment_variables:
     ARIZE_SPACE_KEY: "d0*****"
     ARIZE_API_KEY: "141a****"
-    ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize api endpoint
+    ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize GRPC api endpoint
+    ARIZE_HTTP_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize HTTP api endpoint. Set either this or ARIZE_ENDPOINT
 ```

 ## Support & Talk to Founders
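As a sketch of what this change enables from the SDK side (assumes an OPENAI_API_KEY is configured; the Arize credentials below are dummies):

```python
import os
import litellm

# Dummy credentials for illustration only.
os.environ["ARIZE_SPACE_KEY"] = "d0*****"
os.environ["ARIZE_API_KEY"] = "141a****"
# Setting ARIZE_HTTP_ENDPOINT (instead of ARIZE_ENDPOINT) makes the
# integration export spans over OTLP HTTP rather than gRPC.
os.environ["ARIZE_HTTP_ENDPOINT"] = "https://otlp.arize.com/v1"

litellm.callbacks = ["arize"]  # enable the Arize logging integration

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
)
```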
@@ -1279,7 +1279,8 @@ litellm_settings:
 environment_variables:
     ARIZE_SPACE_KEY: "d0*****"
     ARIZE_API_KEY: "141a****"
-    ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize api endpoint
+    ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize GRPC api endpoint
+    ARIZE_HTTP_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize HTTP api endpoint. Set either this or ARIZE_ENDPOINT
 ```

 2. Start Proxy
@@ -7,135 +7,208 @@ this file has Arize ai specific helper functions

This hunk restructures litellm/integrations/arize_ai.py: the module-level make_json_serializable and set_arize_ai_attributes helpers become @staticmethods on a new ArizeLogger class, verbose_proxy_logger is swapped for verbose_logger, and two new helpers (_get_arize_config, get_arize_opentelemetry_config) are added. The file after this change:

import json
from typing import TYPE_CHECKING, Any, Optional, Union

from litellm._logging import verbose_logger

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig

    Span = _Span
    OpenTelemetryConfig = _OpenTelemetryConfig
else:
    Span = Any
    OpenTelemetryConfig = Any

import os

from litellm.types.integrations.arize import *


class ArizeLogger:
    @staticmethod
    def set_arize_ai_attributes(span: Span, kwargs, response_obj):
        from litellm.integrations._types.open_inference import (
            MessageAttributes,
            MessageContentAttributes,
            OpenInferenceSpanKindValues,
            SpanAttributes,
        )

        try:
            optional_params = kwargs.get("optional_params", {})
            # litellm_params = kwargs.get("litellm_params", {}) or {}

            #############################################
            ############ LLM CALL METADATA ##############
            #############################################
            # commented out for now - looks like Arize AI could not log this
            # metadata = litellm_params.get("metadata", {}) or {}
            # span.set_attribute(SpanAttributes.METADATA, str(metadata))

            #############################################
            ########## LLM Request Attributes ###########
            #############################################

            # The name of the LLM a request is being made to
            if kwargs.get("model"):
                span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model"))

            span.set_attribute(
                SpanAttributes.OPENINFERENCE_SPAN_KIND,
                OpenInferenceSpanKindValues.LLM.value,
            )
            messages = kwargs.get("messages")

            # for /chat/completions
            # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
            if messages:
                span.set_attribute(
                    SpanAttributes.INPUT_VALUE,
                    messages[-1].get("content", ""),  # get the last message for input
                )

                # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page
                for idx, msg in enumerate(messages):
                    # Set the role per message
                    span.set_attribute(
                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}",
                        msg["role"],
                    )
                    # Set the content per message
                    span.set_attribute(
                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}",
                        msg.get("content", ""),
                    )

            # The Generative AI Provider: Azure, OpenAI, etc.
            _optional_params = ArizeLogger.make_json_serializable(optional_params)
            _json_optional_params = json.dumps(_optional_params)
            span.set_attribute(
                SpanAttributes.LLM_INVOCATION_PARAMETERS, _json_optional_params
            )

            if optional_params.get("user"):
                span.set_attribute(SpanAttributes.USER_ID, optional_params.get("user"))

            #############################################
            ########## LLM Response Attributes ##########
            # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
            #############################################
            for choice in response_obj.get("choices"):
                response_message = choice.get("message", {})
                span.set_attribute(
                    SpanAttributes.OUTPUT_VALUE, response_message.get("content", "")
                )

                # This shows up under `output_messages` tab on the span page
                # This code assumes a single response
                span.set_attribute(
                    f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
                    response_message["role"],
                )
                span.set_attribute(
                    f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
                    response_message.get("content", ""),
                )

            usage = response_obj.get("usage")
            if usage:
                span.set_attribute(
                    SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
                    usage.get("total_tokens"),
                )

                # The number of tokens used in the LLM response (completion).
                span.set_attribute(
                    SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
                    usage.get("completion_tokens"),
                )

                # The number of tokens used in the LLM prompt.
                span.set_attribute(
                    SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
                    usage.get("prompt_tokens"),
                )
            pass
        except Exception as e:
            verbose_logger.error(f"Error setting arize attributes: {e}")

    @staticmethod
    def _get_arize_config() -> ArizeConfig:
        """
        Helper function to get Arize configuration.

        Returns:
            ArizeConfig: A Pydantic model containing Arize configuration.

        Raises:
            ValueError: If required environment variables are not set.
        """
        space_key = os.environ.get("ARIZE_SPACE_KEY")
        api_key = os.environ.get("ARIZE_API_KEY")

        if not space_key:
            raise ValueError("ARIZE_SPACE_KEY not found in environment variables")
        if not api_key:
            raise ValueError("ARIZE_API_KEY not found in environment variables")

        grpc_endpoint = os.environ.get("ARIZE_ENDPOINT")
        http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT")
        if grpc_endpoint is None and http_endpoint is None:
            # use default arize grpc endpoint
            verbose_logger.debug(
                "No ARIZE_ENDPOINT or ARIZE_HTTP_ENDPOINT found, using default endpoint: https://otlp.arize.com/v1"
            )
            grpc_endpoint = "https://otlp.arize.com/v1"

        return ArizeConfig(
            space_key=space_key,
            api_key=api_key,
            grpc_endpoint=grpc_endpoint,
            http_endpoint=http_endpoint,
        )

    @staticmethod
    def get_arize_opentelemetry_config() -> Optional[OpenTelemetryConfig]:
        """
        Helper function to get OpenTelemetry configuration for Arize.

        Returns:
            OpenTelemetryConfig: Configuration for OpenTelemetry.
        """
        from .opentelemetry import OpenTelemetryConfig

        arize_config = ArizeLogger._get_arize_config()
        if arize_config.http_endpoint:
            return OpenTelemetryConfig(
                exporter="otlp_http",
                endpoint=arize_config.http_endpoint,
            )

        # use default arize grpc endpoint
        return OpenTelemetryConfig(
            exporter="otlp_grpc",
            endpoint=arize_config.grpc_endpoint,
        )

    @staticmethod
    def make_json_serializable(payload: dict) -> dict:
        for key, value in payload.items():
            try:
                if isinstance(value, dict):
                    # recursively sanitize dicts
                    payload[key] = ArizeLogger.make_json_serializable(value.copy())
                elif not isinstance(value, (str, int, float, bool, type(None))):
                    # everything else becomes a string
                    payload[key] = str(value)
            except Exception:
                # non blocking if it can't cast to a str
                pass
        return payload
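A short sketch of the endpoint-selection behavior these helpers implement (dummy env values; mirrors the tests at the bottom of this commit):

```python
import os

from litellm.integrations.arize_ai import ArizeLogger

# Required credentials (dummy values).
os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
os.environ["ARIZE_API_KEY"] = "test_api_key"

# Neither endpoint set: fall back to the default Arize gRPC endpoint.
config = ArizeLogger.get_arize_opentelemetry_config()
assert config.exporter == "otlp_grpc"
assert config.endpoint == "https://otlp.arize.com/v1"

# When ARIZE_HTTP_ENDPOINT is set, it wins and the HTTP exporter is used.
os.environ["ARIZE_HTTP_ENDPOINT"] = "http://test.endpoint"
config = ArizeLogger.get_arize_opentelemetry_config()
assert config.exporter == "otlp_http"
assert config.endpoint == "http://test.endpoint"
```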
@@ -396,9 +396,9 @@ class OpenTelemetry(CustomLogger):
     def set_attributes(self, span: Span, kwargs, response_obj):  # noqa: PLR0915
         try:
             if self.callback_name == "arize":
-                from litellm.integrations.arize_ai import set_arize_ai_attributes
+                from litellm.integrations.arize_ai import ArizeLogger

-                set_arize_ai_attributes(span, kwargs, response_obj)
+                ArizeLogger.set_arize_ai_attributes(span, kwargs, response_obj)
                 return
             elif self.callback_name == "langtrace":
                 from litellm.integrations.langtrace import LangtraceAttributes
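To see what the `arize` branch now does without a live exporter, one can call the static method with a mock span (a sketch; the payload shapes are simplified and `MagicMock` stands in for a real OTEL span):

```python
from unittest.mock import MagicMock

from litellm.integrations.arize_ai import ArizeLogger

span = MagicMock()  # stand-in for an OpenTelemetry span
kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hi"}],
    "optional_params": {"user": "user-123"},
}
response_obj = {
    "choices": [{"message": {"role": "assistant", "content": "Hello!"}}],
    "usage": {"total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7},
}

ArizeLogger.set_arize_ai_attributes(span, kwargs, response_obj)
# Each OpenInference attribute (model name, messages, usage, ...) lands as
# one span.set_attribute(...) call on the mock.
print(span.set_attribute.call_count)
```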
@@ -60,6 +60,7 @@ from litellm.utils import (

 from ..integrations.aispend import AISpendLogger
 from ..integrations.argilla import ArgillaLogger
+from ..integrations.arize_ai import ArizeLogger
 from ..integrations.athina import AthinaLogger
 from ..integrations.braintrust_logging import BraintrustLogger
 from ..integrations.datadog.datadog import DataDogLogger
@@ -2323,22 +2324,16 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         _in_memory_loggers.append(_opik_logger)
         return _opik_logger  # type: ignore
     elif logging_integration == "arize":
-        if "ARIZE_SPACE_KEY" not in os.environ:
-            raise ValueError("ARIZE_SPACE_KEY not found in environment variables")
-        if "ARIZE_API_KEY" not in os.environ:
-            raise ValueError("ARIZE_API_KEY not found in environment variables")
         from litellm.integrations.opentelemetry import (
             OpenTelemetry,
             OpenTelemetryConfig,
         )

-        arize_endpoint = (
-            os.environ.get("ARIZE_ENDPOINT", None) or "https://otlp.arize.com/v1"
-        )
-        otel_config = OpenTelemetryConfig(
-            exporter="otlp_grpc",
-            endpoint=arize_endpoint,
-        )
+        otel_config = ArizeLogger.get_arize_opentelemetry_config()
+        if otel_config is None:
+            raise ValueError(
+                "No valid endpoint found for Arize, please set 'ARIZE_ENDPOINT' to your GRPC endpoint or 'ARIZE_HTTP_ENDPOINT' to your HTTP endpoint"
+            )
         os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
             f"space_key={os.getenv('ARIZE_SPACE_KEY')},api_key={os.getenv('ARIZE_API_KEY')}"
         )
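The header string set here is what authenticates the exporter; for illustration (dummy credentials; OTEL_EXPORTER_OTLP_TRACES_HEADERS is the standard OpenTelemetry env var read by both the gRPC and HTTP OTLP transports):

```python
import os

os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
os.environ["ARIZE_API_KEY"] = "test_api_key"

# Same expression as in the init path above: plain key=value pairs.
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
    f"space_key={os.getenv('ARIZE_SPACE_KEY')},api_key={os.getenv('ARIZE_API_KEY')}"
)
print(os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"])
# space_key=test_space_key,api_key=test_api_key
```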
@@ -2351,7 +2346,6 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         _otel_logger = OpenTelemetry(config=otel_config, callback_name="arize")
         _in_memory_loggers.append(_otel_logger)
         return _otel_logger  # type: ignore
-
     elif logging_integration == "otel":
         from litellm.integrations.opentelemetry import OpenTelemetry

litellm/types/integrations/arize.py (new file, 10 additions)
@@ -0,0 +1,10 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class ArizeConfig(BaseModel):
+    space_key: str
+    api_key: str
+    grpc_endpoint: Optional[str] = None
+    http_endpoint: Optional[str] = None
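A minimal sketch of using the new model directly (normally `ArizeLogger._get_arize_config()` constructs it; both endpoints are optional, so only the credential fields are required):

```python
from litellm.types.integrations.arize import ArizeConfig

config = ArizeConfig(
    space_key="d0*****",
    api_key="141a****",
    http_endpoint="https://otlp.arize.com/v1",  # grpc_endpoint stays None
)
print(config.http_endpoint)
```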
@@ -10,9 +10,9 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
+from litellm.integrations.arize_ai import ArizeConfig, ArizeLogger

 load_dotenv()
 import logging
-

 @pytest.mark.asyncio()
@@ -32,3 +32,57 @@ async def test_async_otel_callback():
     )

     await asyncio.sleep(2)
+
+
+@pytest.fixture
+def mock_env_vars(monkeypatch):
+    monkeypatch.setenv("ARIZE_SPACE_KEY", "test_space_key")
+    monkeypatch.setenv("ARIZE_API_KEY", "test_api_key")
+
+
+def test_get_arize_config(mock_env_vars):
+    """
+    Use Arize default endpoint when no endpoints are provided
+    """
+    config = ArizeLogger._get_arize_config()
+    assert isinstance(config, ArizeConfig)
+    assert config.space_key == "test_space_key"
+    assert config.api_key == "test_api_key"
+    assert config.grpc_endpoint == "https://otlp.arize.com/v1"
+    assert config.http_endpoint is None
+
+
+def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch):
+    """
+    Use provided endpoints when they are set
+    """
+    monkeypatch.setenv("ARIZE_ENDPOINT", "grpc://test.endpoint")
+    monkeypatch.setenv("ARIZE_HTTP_ENDPOINT", "http://test.endpoint")
+
+    config = ArizeLogger._get_arize_config()
+    assert config.grpc_endpoint == "grpc://test.endpoint"
+    assert config.http_endpoint == "http://test.endpoint"
+
+
+def test_get_arize_opentelemetry_config_grpc(mock_env_vars, monkeypatch):
+    """
+    Use provided GRPC endpoint when it is set
+    """
+    monkeypatch.setenv("ARIZE_ENDPOINT", "grpc://test.endpoint")
+
+    config = ArizeLogger.get_arize_opentelemetry_config()
+    assert isinstance(config, OpenTelemetryConfig)
+    assert config.exporter == "otlp_grpc"
+    assert config.endpoint == "grpc://test.endpoint"
+
+
+def test_get_arize_opentelemetry_config_http(mock_env_vars, monkeypatch):
+    """
+    Use provided HTTP endpoint when it is set
+    """
+    monkeypatch.setenv("ARIZE_HTTP_ENDPOINT", "http://test.endpoint")
+
+    config = ArizeLogger.get_arize_opentelemetry_config()
+    assert isinstance(config, OpenTelemetryConfig)
+    assert config.exporter == "otlp_http"
+    assert config.endpoint == "http://test.endpoint"