fix(telemetry): fixes based on testing

This commit is contained in:
Emilio Garcia 2025-11-13 15:20:18 -05:00
parent 8b46d5966b
commit ce92a44d08
5 changed files with 46 additions and 29 deletions

View file

@ -371,6 +371,12 @@ class SafetyConfig(BaseModel):
)
class TelemetryConfig(BaseModel):
"""Configuration for telemetry collection."""
enabled: bool = Field(default=False, description="Whether telemetry collection is enabled")
class QuotaPeriod(StrEnum):
DAY = "day"
@ -536,6 +542,11 @@ can be instantiated multiple times (with different configs) if necessary.
description="Configuration for default moderations model",
)
telemetry: TelemetryConfig | None = Field(
default=None,
description="Configuration for telemetry collection",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):

View file

@ -85,7 +85,6 @@ async def get_auto_router_impl(
)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
elif api == Api.vector_io:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
elif api == Api.safety:

View file

@ -7,8 +7,7 @@
from typing import Any
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME
from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes
from llama_stack.core.telemetry.helpers import safety_request_span_attributes, safety_span_name
from llama_stack.log import get_logger
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
@ -50,22 +49,22 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
return await self.routing_table.unregister_shield(identifier)
@tracer.start_as_current_span(name=GUARDRAIL_SPAN_NAME)
async def run_shield(
self,
shield_id: str,
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
response = await provider.run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
with tracer.start_as_current_span(name=safety_span_name(shield_id)):
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
response = await provider.run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
guardrail_request_span_attributes(shield_id, messages, response)
safety_request_span_attributes(shield_id, messages, response)
return response
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:

View file

@ -4,14 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Safety Attributes
GUARDRAIL_SPAN_NAME = "llama_stack.guardrail"
llama_stack_prefix = "llama_stack"
SAFETY_REQUEST_PREFIX = "llama_stack.safety.request"
# Safety Attributes
SAFETY_SPAN_NAME = "safety.run_shield"
SAFETY_REQUEST_PREFIX = f"{llama_stack_prefix}.safety.request"
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
SAFETY_RESPONSE_PREFIX = "llama_stack.safety.response"
SAFETY_RESPONSE_PREFIX = f"{llama_stack_prefix}.safety.response"
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"

View file

@ -17,20 +17,26 @@ from .constants import (
SAFETY_RESPONSE_METADATA_ATTRIBUTE,
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
SAFETY_SPAN_NAME,
)
def guardrail_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
span = trace.get_current_span()
if span is not None:
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
messages_json = json.dumps([msg.model_dump() for msg in messages])
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
def safety_span_name(shield_id: str) -> str:
return f"{SAFETY_SPAN_NAME} {shield_id}"
if response.violation:
if response.violation.metadata:
metadata_json = json.dumps(response.violation.metadata)
span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json)
if response.violation.user_message:
span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message)
span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value)
# TODO: Consider using Wrapt to automatically instrument code
# This is the industry standard way to package automatically instrumentation in python.
def safety_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
span = trace.get_current_span()
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
messages_json = json.dumps([msg.model_dump() for msg in messages])
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
if response.violation:
if response.violation.metadata:
metadata_json = json.dumps(response.violation.metadata)
span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json)
if response.violation.user_message:
span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message)
span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value)