From b75019c1a502f9af60f8387effb822763233a5d4 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 23 Oct 2024 09:38:35 +0530
Subject: [PATCH] (feat) Arize - Allow using Arize HTTP endpoint (#6364)

* arize use helper for get_arize_opentelemetry_config

* use helper to get Arize OTEL config

* arize add helpers for arize

* docs allow using arize http endpoint

* fix importing OTEL for Arize

* use static methods for ArizeLogger

* fix ArizeLogger tests
---
 .../docs/observability/arize_integration.md   |   3 +-
 docs/my-website/docs/proxy/logging.md         |   3 +-
 litellm/integrations/arize_ai.py              | 287 +++++++++++-------
 litellm/integrations/opentelemetry.py         |   4 +-
 litellm/litellm_core_utils/litellm_logging.py |  18 +-
 litellm/types/integrations/arize.py           |  10 +
 tests/local_testing/test_arize_ai.py          |  56 +++-
 7 files changed, 257 insertions(+), 124 deletions(-)
 create mode 100644 litellm/types/integrations/arize.py

diff --git a/docs/my-website/docs/observability/arize_integration.md b/docs/my-website/docs/observability/arize_integration.md
index 17be003f8..a69d32e5b 100644
--- a/docs/my-website/docs/observability/arize_integration.md
+++ b/docs/my-website/docs/observability/arize_integration.md
@@ -62,7 +62,8 @@ litellm_settings:
 environment_variables:
   ARIZE_SPACE_KEY: "d0*****"
   ARIZE_API_KEY: "141a****"
-  ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize api endpoint
+  ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom Arize gRPC API endpoint
+  ARIZE_HTTP_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom Arize HTTP API endpoint. Set this or ARIZE_ENDPOINT; if both are set, the HTTP endpoint is used
 ```
 
 ## Support & Talk to Founders
 
diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md
index 3a764005f..72c2e3773 100644
--- a/docs/my-website/docs/proxy/logging.md
+++ b/docs/my-website/docs/proxy/logging.md
@@ -1279,7 +1279,8 @@ litellm_settings:
 environment_variables:
   ARIZE_SPACE_KEY: "d0*****"
   ARIZE_API_KEY: "141a****"
-  ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize api endpoint
+  ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom Arize gRPC API endpoint
+  ARIZE_HTTP_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom Arize HTTP API endpoint. Set this or ARIZE_ENDPOINT; if both are set, the HTTP endpoint is used
 ```
 
 2. Start Proxy
diff --git a/litellm/integrations/arize_ai.py b/litellm/integrations/arize_ai.py
index 5a66cfd0c..acd3f745b 100644
--- a/litellm/integrations/arize_ai.py
+++ b/litellm/integrations/arize_ai.py
@@ -7,135 +7,208 @@ this file has Arize ai specific helper functions
 import json
 from typing import TYPE_CHECKING, Any, Optional, Union
 
-from litellm._logging import verbose_proxy_logger
+from litellm._logging import verbose_logger
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
+    from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
+
     Span = _Span
+    OpenTelemetryConfig = _OpenTelemetryConfig
 else:
     Span = Any
+    OpenTelemetryConfig = Any
+
+import os
+
+from litellm.types.integrations.arize import *
 
 
-def make_json_serializable(payload: dict) -> dict:
-    for key, value in payload.items():
+class ArizeLogger:
+    @staticmethod
+    def set_arize_ai_attributes(span: Span, kwargs, response_obj):
+        from litellm.integrations._types.open_inference import (
+            MessageAttributes,
+            MessageContentAttributes,
+            OpenInferenceSpanKindValues,
+            SpanAttributes,
+        )
+
         try:
-            if isinstance(value, dict):
-                # recursively sanitize dicts
-                payload[key] = make_json_serializable(value.copy())
-            elif not isinstance(value, (str, int, float, bool, type(None))):
-                # everything else becomes a string
-                payload[key] = str(value)
-        except Exception:
-            # non blocking if it can't cast to a str
+
+            optional_params = kwargs.get("optional_params", {})
+            # litellm_params = kwargs.get("litellm_params", {}) or {}
+
+            #############################################
+            ############ LLM CALL METADATA ##############
+            #############################################
+            # commented out for now - looks like Arize AI could not log this
+            # metadata = litellm_params.get("metadata", {}) or {}
+            # span.set_attribute(SpanAttributes.METADATA, str(metadata))
+
+            #############################################
+            ########## LLM Request Attributes ###########
+            #############################################
+
+            # The name of the LLM a request is being made to
+            if kwargs.get("model"):
+                span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model"))
+
+            span.set_attribute(
+                SpanAttributes.OPENINFERENCE_SPAN_KIND,
+                OpenInferenceSpanKindValues.LLM.value,
+            )
+            messages = kwargs.get("messages")
+
+            # for /chat/completions
+            # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
+            if messages:
+                span.set_attribute(
+                    SpanAttributes.INPUT_VALUE,
+                    messages[-1].get("content", ""),  # get the last message for input
+                )
+
+                # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page
+                for idx, msg in enumerate(messages):
+                    # Set the role per message
+                    span.set_attribute(
+                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}",
+                        msg["role"],
+                    )
+                    # Set the content per message
+                    span.set_attribute(
+                        f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}",
+                        msg.get("content", ""),
+                    )
+
+            # Serialize the optional params and record them as the LLM invocation parameters
+            _optional_params = ArizeLogger.make_json_serializable(optional_params)
+            _json_optional_params = json.dumps(_optional_params)
+            span.set_attribute(
+                SpanAttributes.LLM_INVOCATION_PARAMETERS, _json_optional_params
+            )
+
+            if optional_params.get("user"):
+                span.set_attribute(SpanAttributes.USER_ID, optional_params.get("user"))
+
+            #############################################
+            ########## LLM Response Attributes ##########
+            # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions
+            #############################################
+            for choice in response_obj.get("choices"):
+                response_message = choice.get("message", {})
+                span.set_attribute(
+                    SpanAttributes.OUTPUT_VALUE, response_message.get("content", "")
+                )
+
+                # This shows up under `output_messages` tab on the span page
+                # This code assumes a single response
+                span.set_attribute(
+                    f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}",
+                    response_message["role"],
+                )
+                span.set_attribute(
+                    f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}",
+                    response_message.get("content", ""),
+                )
+
+            usage = response_obj.get("usage")
+            if usage:
+                span.set_attribute(
+                    SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
+                    usage.get("total_tokens"),
+                )
+
+                # The number of tokens used in the LLM response (completion).
+                span.set_attribute(
+                    SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
+                    usage.get("completion_tokens"),
+                )
+
+                # The number of tokens used in the LLM prompt.
+                span.set_attribute(
+                    SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
+                    usage.get("prompt_tokens"),
+                )
             pass
-    return payload
+        except Exception as e:
+            verbose_logger.error(f"Error setting arize attributes: {e}")
 
+    ###################### Helper functions ######################
 
-def set_arize_ai_attributes(span: Span, kwargs, response_obj):
-    from litellm.integrations._types.open_inference import (
-        MessageAttributes,
-        MessageContentAttributes,
-        OpenInferenceSpanKindValues,
-        SpanAttributes,
-    )
+    @staticmethod
+    def _get_arize_config() -> ArizeConfig:
+        """
+        Helper function to get Arize configuration.
 
-    try:
+        Returns:
+            ArizeConfig: A Pydantic model containing Arize configuration.
 
-        optional_params = kwargs.get("optional_params", {})
-        # litellm_params = kwargs.get("litellm_params", {}) or {}
+        Raises:
+            ValueError: If required environment variables are not set.
+ """ + space_key = os.environ.get("ARIZE_SPACE_KEY") + api_key = os.environ.get("ARIZE_API_KEY") - ############################################# - ############ LLM CALL METADATA ############## - ############################################# - # commented out for now - looks like Arize AI could not log this - # metadata = litellm_params.get("metadata", {}) or {} - # span.set_attribute(SpanAttributes.METADATA, str(metadata)) + if not space_key: + raise ValueError("ARIZE_SPACE_KEY not found in environment variables") + if not api_key: + raise ValueError("ARIZE_API_KEY not found in environment variables") - ############################################# - ########## LLM Request Attributes ########### - ############################################# - - # The name of the LLM a request is being made to - if kwargs.get("model"): - span.set_attribute(SpanAttributes.LLM_MODEL_NAME, kwargs.get("model")) - - span.set_attribute( - SpanAttributes.OPENINFERENCE_SPAN_KIND, - OpenInferenceSpanKindValues.LLM.value, - ) - messages = kwargs.get("messages") - - # for /chat/completions - # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions - if messages: - span.set_attribute( - SpanAttributes.INPUT_VALUE, - messages[-1].get("content", ""), # get the last message for input + grpc_endpoint = os.environ.get("ARIZE_ENDPOINT") + http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT") + if grpc_endpoint is None and http_endpoint is None: + # use default arize grpc endpoint + verbose_logger.debug( + "No ARIZE_ENDPOINT or ARIZE_HTTP_ENDPOINT found, using default endpoint: https://otlp.arize.com/v1" ) + grpc_endpoint = "https://otlp.arize.com/v1" - # LLM_INPUT_MESSAGES shows up under `input_messages` tab on the span page - for idx, msg in enumerate(messages): - # Set the role per message - span.set_attribute( - f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}", - msg["role"], - ) - # Set the content per message - span.set_attribute( - f"{SpanAttributes.LLM_INPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}", - msg.get("content", ""), - ) - - # The Generative AI Provider: Azure, OpenAI, etc. - _optional_params = make_json_serializable(optional_params) - _json_optional_params = json.dumps(_optional_params) - span.set_attribute( - SpanAttributes.LLM_INVOCATION_PARAMETERS, _json_optional_params + return ArizeConfig( + space_key=space_key, + api_key=api_key, + grpc_endpoint=grpc_endpoint, + http_endpoint=http_endpoint, ) - if optional_params.get("user"): - span.set_attribute(SpanAttributes.USER_ID, optional_params.get("user")) + @staticmethod + def get_arize_opentelemetry_config() -> Optional[OpenTelemetryConfig]: + """ + Helper function to get OpenTelemetry configuration for Arize. - ############################################# - ########## LLM Response Attributes ########## - # https://docs.arize.com/arize/large-language-models/tracing/semantic-conventions - ############################################# - for choice in response_obj.get("choices"): - response_message = choice.get("message", {}) - span.set_attribute( - SpanAttributes.OUTPUT_VALUE, response_message.get("content", "") + Args: + arize_config (ArizeConfig): Arize configuration object. + + Returns: + OpenTelemetryConfig: Configuration for OpenTelemetry. 
+ """ + from .opentelemetry import OpenTelemetryConfig + + arize_config = ArizeLogger._get_arize_config() + if arize_config.http_endpoint: + return OpenTelemetryConfig( + exporter="otlp_http", + endpoint=arize_config.http_endpoint, ) - # This shows up under `output_messages` tab on the span page - # This code assumes a single response - span.set_attribute( - f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_ROLE}", - response_message["role"], - ) - span.set_attribute( - f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENT}", - response_message.get("content", ""), - ) + # use default arize grpc endpoint + return OpenTelemetryConfig( + exporter="otlp_grpc", + endpoint=arize_config.grpc_endpoint, + ) - usage = response_obj.get("usage") - if usage: - span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_TOTAL, - usage.get("total_tokens"), - ) - - # The number of tokens used in the LLM response (completion). - span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, - usage.get("completion_tokens"), - ) - - # The number of tokens used in the LLM prompt. - span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_PROMPT, - usage.get("prompt_tokens"), - ) - pass - except Exception as e: - verbose_proxy_logger.error(f"Error setting arize attributes: {e}") + @staticmethod + def make_json_serializable(payload: dict) -> dict: + for key, value in payload.items(): + try: + if isinstance(value, dict): + # recursively sanitize dicts + payload[key] = ArizeLogger.make_json_serializable(value.copy()) + elif not isinstance(value, (str, int, float, bool, type(None))): + # everything else becomes a string + payload[key] = str(value) + except Exception: + # non blocking if it can't cast to a str + pass + return payload diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 171ec21e7..8ba871acc 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -396,9 +396,9 @@ class OpenTelemetry(CustomLogger): def set_attributes(self, span: Span, kwargs, response_obj): # noqa: PLR0915 try: if self.callback_name == "arize": - from litellm.integrations.arize_ai import set_arize_ai_attributes + from litellm.integrations.arize_ai import ArizeLogger - set_arize_ai_attributes(span, kwargs, response_obj) + ArizeLogger.set_arize_ai_attributes(span, kwargs, response_obj) return elif self.callback_name == "langtrace": from litellm.integrations.langtrace import LangtraceAttributes diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index f41ac256b..7aee38151 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -60,6 +60,7 @@ from litellm.utils import ( from ..integrations.aispend import AISpendLogger from ..integrations.argilla import ArgillaLogger +from ..integrations.arize_ai import ArizeLogger from ..integrations.athina import AthinaLogger from ..integrations.braintrust_logging import BraintrustLogger from ..integrations.datadog.datadog import DataDogLogger @@ -2323,22 +2324,16 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_opik_logger) return _opik_logger # type: ignore elif logging_integration == "arize": - if "ARIZE_SPACE_KEY" not in os.environ: - raise ValueError("ARIZE_SPACE_KEY not found in environment variables") - if "ARIZE_API_KEY" not in os.environ: - raise ValueError("ARIZE_API_KEY not found in environment variables") from 
             OpenTelemetry,
             OpenTelemetryConfig,
         )
 
-        arize_endpoint = (
-            os.environ.get("ARIZE_ENDPOINT", None) or "https://otlp.arize.com/v1"
-        )
-        otel_config = OpenTelemetryConfig(
-            exporter="otlp_grpc",
-            endpoint=arize_endpoint,
-        )
+        otel_config = ArizeLogger.get_arize_opentelemetry_config()
+        if otel_config is None:
+            raise ValueError(
+                "No valid endpoint found for Arize, please set 'ARIZE_ENDPOINT' to your gRPC endpoint or 'ARIZE_HTTP_ENDPOINT' to your HTTP endpoint"
+            )
         os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
             f"space_key={os.getenv('ARIZE_SPACE_KEY')},api_key={os.getenv('ARIZE_API_KEY')}"
         )
@@ -2351,7 +2346,6 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         _otel_logger = OpenTelemetry(config=otel_config, callback_name="arize")
         _in_memory_loggers.append(_otel_logger)
         return _otel_logger  # type: ignore
-
     elif logging_integration == "otel":
         from litellm.integrations.opentelemetry import OpenTelemetry
 
diff --git a/litellm/types/integrations/arize.py b/litellm/types/integrations/arize.py
new file mode 100644
index 000000000..3c0bbcde0
--- /dev/null
+++ b/litellm/types/integrations/arize.py
@@ -0,0 +1,10 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class ArizeConfig(BaseModel):
+    space_key: str
+    api_key: str
+    grpc_endpoint: Optional[str] = None
+    http_endpoint: Optional[str] = None
diff --git a/tests/local_testing/test_arize_ai.py b/tests/local_testing/test_arize_ai.py
index f6ccc75f2..24aed3da7 100644
--- a/tests/local_testing/test_arize_ai.py
+++ b/tests/local_testing/test_arize_ai.py
@@ -10,9 +10,9 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanE
 import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
+from litellm.integrations.arize_ai import ArizeConfig, ArizeLogger
 
 load_dotenv()
-import logging
 
 
 @pytest.mark.asyncio()
@@ -32,3 +32,57 @@ async def test_async_otel_callback():
     )
 
     await asyncio.sleep(2)
+
+
+@pytest.fixture
+def mock_env_vars(monkeypatch):
+    monkeypatch.setenv("ARIZE_SPACE_KEY", "test_space_key")
+    monkeypatch.setenv("ARIZE_API_KEY", "test_api_key")
+
+
+def test_get_arize_config(mock_env_vars):
+    """
+    Use Arize default endpoint when no endpoints are provided
+    """
+    config = ArizeLogger._get_arize_config()
+    assert isinstance(config, ArizeConfig)
+    assert config.space_key == "test_space_key"
+    assert config.api_key == "test_api_key"
+    assert config.grpc_endpoint == "https://otlp.arize.com/v1"
+    assert config.http_endpoint is None
+
+
+def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch):
+    """
+    Use provided endpoints when they are set
+    """
+    monkeypatch.setenv("ARIZE_ENDPOINT", "grpc://test.endpoint")
+    monkeypatch.setenv("ARIZE_HTTP_ENDPOINT", "http://test.endpoint")
+
+    config = ArizeLogger._get_arize_config()
+    assert config.grpc_endpoint == "grpc://test.endpoint"
+    assert config.http_endpoint == "http://test.endpoint"
+
+
+def test_get_arize_opentelemetry_config_grpc(mock_env_vars, monkeypatch):
+    """
+    Use provided gRPC endpoint when it is set
+    """
+    monkeypatch.setenv("ARIZE_ENDPOINT", "grpc://test.endpoint")
+
+    config = ArizeLogger.get_arize_opentelemetry_config()
+    assert isinstance(config, OpenTelemetryConfig)
+    assert config.exporter == "otlp_grpc"
+    assert config.endpoint == "grpc://test.endpoint"
+
+
+def test_get_arize_opentelemetry_config_http(mock_env_vars, monkeypatch):
+    """
+    Use provided HTTP endpoint when it is set
+    """
+    monkeypatch.setenv("ARIZE_HTTP_ENDPOINT", "http://test.endpoint")
+
+    config = ArizeLogger.get_arize_opentelemetry_config()
+    assert isinstance(config, OpenTelemetryConfig)
+    assert config.exporter == "otlp_http"
+    assert config.endpoint == "http://test.endpoint"
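
Usage sketch for the new HTTP path (a minimal example assuming this change is
installed and a provider key such as OPENAI_API_KEY is in the environment; the
credential, endpoint, and model values below are placeholders):

    import os

    import litellm

    # Placeholder Arize credentials - substitute real values.
    os.environ["ARIZE_SPACE_KEY"] = "your-space-key"
    os.environ["ARIZE_API_KEY"] = "your-api-key"

    # Export spans over HTTP instead of gRPC. get_arize_opentelemetry_config()
    # checks ARIZE_HTTP_ENDPOINT first, so it wins if ARIZE_ENDPOINT (gRPC) is
    # also set; if neither is set, the default gRPC endpoint
    # https://otlp.arize.com/v1 is used.
    os.environ["ARIZE_HTTP_ENDPOINT"] = "https://otlp.arize.com/v1"

    litellm.callbacks = ["arize"]  # registers the Arize OpenTelemetry logger

    response = litellm.completion(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)

The same switch works on the proxy: set ARIZE_HTTP_ENDPOINT instead of
ARIZE_ENDPOINT in the config's environment_variables block, as shown in the
docs changes above.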