fix(telemetry): fixes based on testing

This commit is contained in:
Emilio Garcia 2025-11-13 15:20:18 -05:00
parent 9e2b92b8d2
commit 420d267364
5 changed files with 46 additions and 29 deletions

View file

@ -371,6 +371,12 @@ class SafetyConfig(BaseModel):
) )
class TelemetryConfig(BaseModel):
    """Configuration for telemetry collection."""

    # Telemetry is opt-in: collection stays off unless this flag is set to True.
    enabled: bool = Field(
        default=False,
        description="Whether telemetry collection is enabled",
    )
class QuotaPeriod(StrEnum): class QuotaPeriod(StrEnum):
DAY = "day" DAY = "day"
@ -536,6 +542,11 @@ can be instantiated multiple times (with different configs) if necessary.
description="Configuration for default moderations model", description="Configuration for default moderations model",
) )
telemetry: TelemetryConfig | None = Field(
default=None,
description="Configuration for telemetry collection",
)
@field_validator("external_providers_dir") @field_validator("external_providers_dir")
@classmethod @classmethod
def validate_external_providers_dir(cls, v): def validate_external_providers_dir(cls, v):

View file

@ -85,7 +85,6 @@ async def get_auto_router_impl(
) )
await inference_store.initialize() await inference_store.initialize()
api_to_dep_impl["store"] = inference_store api_to_dep_impl["store"] = inference_store
elif api == Api.vector_io: elif api == Api.vector_io:
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
elif api == Api.safety: elif api == Api.safety:

View file

@ -7,8 +7,7 @@
from typing import Any from typing import Any
from llama_stack.core.datatypes import SafetyConfig from llama_stack.core.datatypes import SafetyConfig
from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME from llama_stack.core.telemetry.helpers import safety_request_span_attributes, safety_span_name
from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
@ -50,13 +49,13 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.unregister_shield: {identifier}") logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
return await self.routing_table.unregister_shield(identifier) return await self.routing_table.unregister_shield(identifier)
@tracer.start_as_current_span(name=GUARDRAIL_SPAN_NAME)
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,
messages: list[OpenAIMessageParam], messages: list[OpenAIMessageParam],
params: dict[str, Any] = None, params: dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
with tracer.start_as_current_span(name=safety_span_name(shield_id)):
logger.debug(f"SafetyRouter.run_shield: {shield_id}") logger.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id) provider = await self.routing_table.get_provider_impl(shield_id)
response = await provider.run_shield( response = await provider.run_shield(
@ -65,7 +64,7 @@ class SafetyRouter(Safety):
params=params, params=params,
) )
guardrail_request_span_attributes(shield_id, messages, response) safety_request_span_attributes(shield_id, messages, response)
return response return response
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject: async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:

View file

@ -4,14 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Safety Attributes llama_stack_prefix = "llama_stack"
GUARDRAIL_SPAN_NAME = "llama_stack.guardrail"
SAFETY_REQUEST_PREFIX = "llama_stack.safety.request" # Safety Attributes
SAFETY_SPAN_NAME = "safety.run_shield"
SAFETY_REQUEST_PREFIX = f"{llama_stack_prefix}.safety.request"
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id" SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages" SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
SAFETY_RESPONSE_PREFIX = "llama_stack.safety.response" SAFETY_RESPONSE_PREFIX = f"{llama_stack_prefix}.safety.response"
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata" SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level" SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message" SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"

View file

@ -17,12 +17,18 @@ from .constants import (
SAFETY_RESPONSE_METADATA_ATTRIBUTE, SAFETY_RESPONSE_METADATA_ATTRIBUTE,
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
SAFETY_SPAN_NAME,
) )
def guardrail_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None: def safety_span_name(shield_id: str) -> str:
return f"{SAFETY_SPAN_NAME} {shield_id}"
# TODO: Consider using wrapt to automatically instrument code.
# This is the industry-standard way to package automatic instrumentation in Python.
def safety_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
span = trace.get_current_span() span = trace.get_current_span()
if span is not None:
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id) span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
messages_json = json.dumps([msg.model_dump() for msg in messages]) messages_json = json.dumps([msg.model_dump() for msg in messages])
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json) span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)