From 8bbbfaba5dd23512865f0d221d52a617c10a928d Mon Sep 17 00:00:00 2001 From: Emilio Garcia Date: Tue, 11 Nov 2025 13:26:25 -0500 Subject: [PATCH] fix(telemetry): Safety APIs Use OpenTelemetry Natively This change created a standardized way to handle telemetry internally. All custom names that are not a semantic convention are maintained in constants.py. Helper functions to capture custom telemetry data not captured by automatic instrumentation are in helpers.py. --- src/llama_stack/core/routers/safety.py | 11 +++++- src/llama_stack/core/telemetry/__init__.py | 27 -------------- src/llama_stack/core/telemetry/constants.py | 17 +++++++++ src/llama_stack/core/telemetry/helpers.py | 36 +++++++++++++++++++ .../inline/agents/meta_reference/safety.py | 16 ++++----- 5 files changed, 69 insertions(+), 38 deletions(-) create mode 100644 src/llama_stack/core/telemetry/constants.py create mode 100644 src/llama_stack/core/telemetry/helpers.py diff --git a/src/llama_stack/core/routers/safety.py b/src/llama_stack/core/routers/safety.py index 2bc99f14f..ee9dca8c5 100644 --- a/src/llama_stack/core/routers/safety.py +++ b/src/llama_stack/core/routers/safety.py @@ -7,10 +7,15 @@ from typing import Any from llama_stack.core.datatypes import SafetyConfig +from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME +from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes from llama_stack.log import get_logger from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield +from opentelemetry import trace + logger = get_logger(name=__name__, category="core::routers") +tracer = trace.get_tracer(__name__) class SafetyRouter(Safety): @@ -45,6 +50,7 @@ class SafetyRouter(Safety): logger.debug(f"SafetyRouter.unregister_shield: {identifier}") return await self.routing_table.unregister_shield(identifier) + @tracer.start_as_current_span(name=GUARDRAIL_SPAN_NAME) async def run_shield( self, shield_id: str, @@ 
-53,12 +59,15 @@ class SafetyRouter(Safety): ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") provider = await self.routing_table.get_provider_impl(shield_id) - return await provider.run_shield( + response = await provider.run_shield( shield_id=shield_id, messages=messages, params=params, ) + guardrail_request_span_attributes(shield_id, messages, response) + return response + async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject: list_shields_response = await self.routing_table.list_shields() shields = list_shields_response.data diff --git a/src/llama_stack/core/telemetry/__init__.py b/src/llama_stack/core/telemetry/__init__.py index bab612c0d..756f351d8 100644 --- a/src/llama_stack/core/telemetry/__init__.py +++ b/src/llama_stack/core/telemetry/__init__.py @@ -3,30 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -from .telemetry import Telemetry -from .trace_protocol import serialize_value, trace_protocol -from .tracing import ( - CURRENT_TRACE_CONTEXT, - ROOT_SPAN_MARKERS, - end_trace, - enqueue_event, - get_current_span, - setup_logger, - span, - start_trace, -) - -__all__ = [ - "Telemetry", - "trace_protocol", - "serialize_value", - "CURRENT_TRACE_CONTEXT", - "ROOT_SPAN_MARKERS", - "end_trace", - "enqueue_event", - "get_current_span", - "setup_logger", - "span", - "start_trace", -] diff --git a/src/llama_stack/core/telemetry/constants.py b/src/llama_stack/core/telemetry/constants.py new file mode 100644 index 000000000..2a7051dc3 --- /dev/null +++ b/src/llama_stack/core/telemetry/constants.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
+# Safety Attributes
+GUARDRAIL_SPAN_NAME = "llama_stack.guardrail"
+
+SAFETY_REQUEST_PREFIX = "llama_stack.safety.request"
+SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
+SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
+
+SAFETY_RESPONSE_PREFIX = "llama_stack.safety.response"
+SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
+SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
+SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"
diff --git a/src/llama_stack/core/telemetry/helpers.py b/src/llama_stack/core/telemetry/helpers.py
new file mode 100644
index 000000000..f7759e205
--- /dev/null
+++ b/src/llama_stack/core/telemetry/helpers.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+from opentelemetry import trace
+
+from llama_stack_api import OpenAIMessageParam, RunShieldResponse
+
+from .constants import (
+    SAFETY_REQUEST_MESSAGES_ATTRIBUTE,
+    SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE,
+    SAFETY_RESPONSE_METADATA_ATTRIBUTE,
+    SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
+    SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
+)
+
+
+def guardrail_request_span_attributes(
+    shield_id: str, messages: list[OpenAIMessageParam], response: RunShieldResponse
+) -> None:
+    """Record a shield request and its outcome as attributes on the current span.
+
+    Args:
+        shield_id: Identifier of the shield that was run.
+        messages: Messages passed to the shield; stored as a single JSON string.
+        response: Shield result; violation attributes are set only when it
+            carries a violation.
+    """
+    span = trace.get_current_span()
+    # get_current_span() never returns None: with no active span it yields a
+    # non-recording span, so test is_recording() to skip the serialization work.
+    if not span.is_recording():
+        return
+
+    span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
+    # mode="json" makes model_dump() emit only JSON-encodable values.
+    messages_json = json.dumps([msg.model_dump(mode="json") for msg in messages])
+    span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
+
+    violation = response.violation
+    if not violation:
+        return
+    if violation.metadata:
+        # default=str: stringify values json cannot encode rather than raise a
+        # telemetry-only TypeError out of the caller's shield request.
+        metadata_json = json.dumps(violation.metadata, default=str)
+        span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json)
+    if violation.user_message:
+        span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, violation.user_message)
+    span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, violation.violation_level.value)
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/safety.py b/src/llama_stack/providers/inline/agents/meta_reference/safety.py
index bfb557a99..123a2e283 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/safety.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/safety.py
@@ -6,7 +6,6 @@
 
 import asyncio
 
-from llama_stack.core.telemetry import tracing
 from llama_stack.log import get_logger
 from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
 
@@ -31,15 +30,12 @@ class ShieldRunnerMixin:
         self.output_shields = output_shields
 
     async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
-        async def run_shield_with_span(identifier: str):
-            async with 
tracing.span(f"run_shield_{identifier}"): - return await self.safety_api.run_shield( - shield_id=identifier, - messages=messages, - params={}, - ) - - responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers]) + responses = await asyncio.gather( + *[ + self.safety_api.run_shield(shield_id=identifier, messages=messages, params={}) + for identifier in identifiers + ] + ) for identifier, response in zip(identifiers, responses, strict=False): if not response.violation: continue