fix(telemetry): Safety APIs Use OpenTelemetry Natively

This change created a standardized way to handle telemetry internally.
All custom names that are not a semantic convention are maintained in
constants.py. Helper functions to capture custom telemetry data not
captured by automatic instrumentation are in helpers.py.
This commit is contained in:
Emilio Garcia 2025-11-11 13:26:25 -05:00
parent f0646ab0f6
commit 89b0c69a07
5 changed files with 69 additions and 38 deletions

View file

@ -7,10 +7,15 @@
from typing import Any from typing import Any
from llama_stack.core.datatypes import SafetyConfig from llama_stack.core.datatypes import SafetyConfig
from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME
from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
from opentelemetry import trace
logger = get_logger(name=__name__, category="core::routers") logger = get_logger(name=__name__, category="core::routers")
tracer = trace.get_tracer(__name__)
class SafetyRouter(Safety): class SafetyRouter(Safety):
@ -45,6 +50,7 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.unregister_shield: {identifier}") logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
return await self.routing_table.unregister_shield(identifier) return await self.routing_table.unregister_shield(identifier)
@tracer.start_as_current_span(name=GUARDRAIL_SPAN_NAME)
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,
@ -53,12 +59,15 @@ class SafetyRouter(Safety):
) -> RunShieldResponse: ) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}") logger.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id) provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield( response = await provider.run_shield(
shield_id=shield_id, shield_id=shield_id,
messages=messages, messages=messages,
params=params, params=params,
) )
guardrail_request_span_attributes(shield_id, messages, response)
return response
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject: async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
list_shields_response = await self.routing_table.list_shields() list_shields_response = await self.routing_table.list_shields()
shields = list_shields_response.data shields = list_shields_response.data

View file

@ -3,30 +3,3 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .telemetry import Telemetry
from .trace_protocol import serialize_value, trace_protocol
from .tracing import (
CURRENT_TRACE_CONTEXT,
ROOT_SPAN_MARKERS,
end_trace,
enqueue_event,
get_current_span,
setup_logger,
span,
start_trace,
)
__all__ = [
"Telemetry",
"trace_protocol",
"serialize_value",
"CURRENT_TRACE_CONTEXT",
"ROOT_SPAN_MARKERS",
"end_trace",
"enqueue_event",
"get_current_span",
"setup_logger",
"span",
"start_trace",
]

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Safety Attributes
GUARDRAIL_SPAN_NAME = "llama_stack.guardrail"
SAFETY_REQUEST_PREFIX = "llama_stack.safety.request"
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
SAFETY_RESPONSE_PREFIX = "llama_stack.safety.response"
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from opentelemetry import trace
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse
from .constants import (
SAFETY_REQUEST_MESSAGES_ATTRIBUTE,
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE,
SAFETY_RESPONSE_METADATA_ATTRIBUTE,
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
)
def guardrail_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
span = trace.get_current_span()
if span is not None:
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
messages_json = json.dumps([msg.model_dump() for msg in messages])
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
if response.violation:
if response.violation.metadata:
metadata_json = json.dumps(response.violation.metadata)
span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json)
if response.violation.user_message:
span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message)
span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value)

View file

@ -6,7 +6,6 @@
import asyncio import asyncio
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
@ -31,15 +30,12 @@ class ShieldRunnerMixin:
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None: async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str): responses = await asyncio.gather(
async with tracing.span(f"run_shield_{identifier}"): *[
return await self.safety_api.run_shield( self.safety_api.run_shield(shield_id=identifier, messages=messages, params={})
shield_id=identifier, for identifier in identifiers
messages=messages, ]
params={},
) )
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
for identifier, response in zip(identifiers, responses, strict=False): for identifier, response in zip(identifiers, responses, strict=False):
if not response.violation: if not response.violation:
continue continue