mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(telemetry): Safety APIs Use OpenTelemetry Natively
This change creates a standardized way to handle telemetry internally. All custom attribute names that are not part of a semantic convention are maintained in constants.py. Helper functions that capture custom telemetry data not covered by automatic instrumentation live in helpers.py.
This commit is contained in:
parent
4c18239914
commit
8bbbfaba5d
5 changed files with 69 additions and 38 deletions
|
|
@ -7,10 +7,15 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.core.datatypes import SafetyConfig
|
from llama_stack.core.datatypes import SafetyConfig
|
||||||
|
from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME
|
||||||
|
from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
||||||
|
|
||||||
|
from opentelemetry import trace
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core::routers")
|
logger = get_logger(name=__name__, category="core::routers")
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SafetyRouter(Safety):
|
class SafetyRouter(Safety):
|
||||||
|
|
@ -45,6 +50,7 @@ class SafetyRouter(Safety):
|
||||||
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
||||||
return await self.routing_table.unregister_shield(identifier)
|
return await self.routing_table.unregister_shield(identifier)
|
||||||
|
|
||||||
|
@tracer.start_as_current_span(name=GUARDRAIL_SPAN_NAME)
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
@ -53,12 +59,15 @@ class SafetyRouter(Safety):
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||||
return await provider.run_shield(
|
response = await provider.run_shield(
|
||||||
shield_id=shield_id,
|
shield_id=shield_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
guardrail_request_span_attributes(shield_id, messages, response)
|
||||||
|
return response
|
||||||
|
|
||||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
list_shields_response = await self.routing_table.list_shields()
|
list_shields_response = await self.routing_table.list_shields()
|
||||||
shields = list_shields_response.data
|
shields = list_shields_response.data
|
||||||
|
|
|
||||||
|
|
@ -3,30 +3,3 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .telemetry import Telemetry
|
|
||||||
from .trace_protocol import serialize_value, trace_protocol
|
|
||||||
from .tracing import (
|
|
||||||
CURRENT_TRACE_CONTEXT,
|
|
||||||
ROOT_SPAN_MARKERS,
|
|
||||||
end_trace,
|
|
||||||
enqueue_event,
|
|
||||||
get_current_span,
|
|
||||||
setup_logger,
|
|
||||||
span,
|
|
||||||
start_trace,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Telemetry",
|
|
||||||
"trace_protocol",
|
|
||||||
"serialize_value",
|
|
||||||
"CURRENT_TRACE_CONTEXT",
|
|
||||||
"ROOT_SPAN_MARKERS",
|
|
||||||
"end_trace",
|
|
||||||
"enqueue_event",
|
|
||||||
"get_current_span",
|
|
||||||
"setup_logger",
|
|
||||||
"span",
|
|
||||||
"start_trace",
|
|
||||||
]
|
|
||||||
|
|
|
||||||
17
src/llama_stack/core/telemetry/constants.py
Normal file
17
src/llama_stack/core/telemetry/constants.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Safety Attributes
|
||||||
|
GUARDRAIL_SPAN_NAME = "llama_stack.guardrail"
|
||||||
|
|
||||||
|
SAFETY_REQUEST_PREFIX = "llama_stack.safety.request"
|
||||||
|
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
|
||||||
|
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
|
||||||
|
|
||||||
|
SAFETY_RESPONSE_PREFIX = "llama_stack.safety.response"
|
||||||
|
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
|
||||||
|
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
|
||||||
|
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"
|
||||||
36
src/llama_stack/core/telemetry/helpers.py
Normal file
36
src/llama_stack/core/telemetry/helpers.py
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from opentelemetry import trace
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Message
|
||||||
|
from llama_stack.apis.safety import RunShieldResponse
|
||||||
|
|
||||||
|
from .constants import (
|
||||||
|
SAFETY_REQUEST_MESSAGES_ATTRIBUTE,
|
||||||
|
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE,
|
||||||
|
SAFETY_RESPONSE_METADATA_ATTRIBUTE,
|
||||||
|
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
|
||||||
|
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def guardrail_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
    """Attach a shield call's request and response details to the current span.

    Sets the shield id and the JSON-encoded messages as request attributes.
    When the response carries a violation, also sets the violation level and,
    if present, the JSON-encoded violation metadata and the user message.

    Args:
        shield_id: Identifier of the shield that was run.
        messages: Messages that were passed to the shield; each must expose
            a pydantic-style ``model_dump``.
        response: Result returned by the shield provider.

    Returns:
        None. The only effect is on the current OpenTelemetry span.
    """
    span = trace.get_current_span()
    # get_current_span() never returns None: without an active span it hands
    # back a non-recording placeholder. Checking is_recording() is what
    # actually lets us skip the serialization work when nothing is collected.
    if not span.is_recording():
        return

    span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
    # mode="json" makes pydantic emit JSON-native values (enums, URLs, bytes,
    # datetimes, ...) so json.dumps cannot fail on a message field and break
    # the shield call from inside the telemetry path.
    messages_json = json.dumps([msg.model_dump(mode="json") for msg in messages])
    span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)

    violation = response.violation
    if not violation:
        return

    if violation.metadata:
        span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, json.dumps(violation.metadata))
    if violation.user_message:
        span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, violation.user_message)
    span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, violation.violation_level.value)
|
||||||
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from llama_stack.core.telemetry import tracing
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
|
from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
|
||||||
|
|
||||||
|
|
@ -31,15 +30,12 @@ class ShieldRunnerMixin:
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
||||||
async def run_shield_with_span(identifier: str):
|
responses = await asyncio.gather(
|
||||||
async with tracing.span(f"run_shield_{identifier}"):
|
*[
|
||||||
return await self.safety_api.run_shield(
|
self.safety_api.run_shield(shield_id=identifier, messages=messages, params={})
|
||||||
shield_id=identifier,
|
for identifier in identifiers
|
||||||
messages=messages,
|
]
|
||||||
params={},
|
)
|
||||||
)
|
|
||||||
|
|
||||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
|
||||||
for identifier, response in zip(identifiers, responses, strict=False):
|
for identifier, response in zip(identifiers, responses, strict=False):
|
||||||
if not response.violation:
|
if not response.violation:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue