diff --git a/src/llama_stack/core/datatypes.py b/src/llama_stack/core/datatypes.py index f64286ef5..3a0f891b0 100644 --- a/src/llama_stack/core/datatypes.py +++ b/src/llama_stack/core/datatypes.py @@ -371,6 +371,12 @@ class SafetyConfig(BaseModel): ) +class TelemetryConfig(BaseModel): + """Configuration for telemetry collection.""" + + enabled: bool = Field(default=False, description="Whether telemetry collection is enabled") + + class QuotaPeriod(StrEnum): DAY = "day" @@ -536,6 +542,11 @@ can be instantiated multiple times (with different configs) if necessary. description="Configuration for default moderations model", ) + telemetry: TelemetryConfig | None = Field( + default=None, + description="Configuration for telemetry collection", + ) + @field_validator("external_providers_dir") @classmethod def validate_external_providers_dir(cls, v): diff --git a/src/llama_stack/core/routers/__init__.py b/src/llama_stack/core/routers/__init__.py index 087fab246..c6f8a7ac2 100644 --- a/src/llama_stack/core/routers/__init__.py +++ b/src/llama_stack/core/routers/__init__.py @@ -85,7 +85,6 @@ async def get_auto_router_impl( ) await inference_store.initialize() api_to_dep_impl["store"] = inference_store - elif api == Api.vector_io: api_to_dep_impl["vector_stores_config"] = run_config.vector_stores elif api == Api.safety: diff --git a/src/llama_stack/core/routers/safety.py b/src/llama_stack/core/routers/safety.py index ee9dca8c5..cbf5215a1 100644 --- a/src/llama_stack/core/routers/safety.py +++ b/src/llama_stack/core/routers/safety.py @@ -7,8 +7,7 @@ from typing import Any from llama_stack.core.datatypes import SafetyConfig -from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME -from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes +from llama_stack.core.telemetry.helpers import safety_request_span_attributes, safety_span_name from llama_stack.log import get_logger from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield @@ -50,22 +49,22 @@ class SafetyRouter(Safety): logger.debug(f"SafetyRouter.unregister_shield: {identifier}") return await self.routing_table.unregister_shield(identifier) - @tracer.start_as_current_span(name=GUARDRAIL_SPAN_NAME) async def run_shield( self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None, ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") - provider = await self.routing_table.get_provider_impl(shield_id) - response = await provider.run_shield( - shield_id=shield_id, - messages=messages, - params=params, - ) + with tracer.start_as_current_span(name=safety_span_name(shield_id)): + logger.debug(f"SafetyRouter.run_shield: {shield_id}") + provider = await self.routing_table.get_provider_impl(shield_id) + response = await provider.run_shield( + shield_id=shield_id, + messages=messages, + params=params, + ) - guardrail_request_span_attributes(shield_id, messages, response) + safety_request_span_attributes(shield_id, messages, response) return response async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject: diff --git a/src/llama_stack/core/telemetry/constants.py b/src/llama_stack/core/telemetry/constants.py index 2a7051dc3..6e3ee8d07 100644 --- a/src/llama_stack/core/telemetry/constants.py +++ b/src/llama_stack/core/telemetry/constants.py @@ -4,14 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Safety Attributes -GUARDRAIL_SPAN_NAME = "llama_stack.guardrail" +llama_stack_prefix = "llama_stack" -SAFETY_REQUEST_PREFIX = "llama_stack.safety.request" +# Safety Attributes +SAFETY_SPAN_NAME = "safety.run_shield" + +SAFETY_REQUEST_PREFIX = f"{llama_stack_prefix}.safety.request" SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id" SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages" -SAFETY_RESPONSE_PREFIX = "llama_stack.safety.response" +SAFETY_RESPONSE_PREFIX = f"{llama_stack_prefix}.safety.response" SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata" SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level" SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message" diff --git a/src/llama_stack/core/telemetry/helpers.py b/src/llama_stack/core/telemetry/helpers.py index f7759e205..5956d7591 100644 --- a/src/llama_stack/core/telemetry/helpers.py +++ b/src/llama_stack/core/telemetry/helpers.py @@ -17,20 +17,26 @@ from .constants import ( SAFETY_RESPONSE_METADATA_ATTRIBUTE, SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, + SAFETY_SPAN_NAME, ) -def guardrail_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None: - span = trace.get_current_span() - if span is not None: - span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id) - messages_json = json.dumps([msg.model_dump() for msg in messages]) - span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json) +def safety_span_name(shield_id: str) -> str: + return f"{SAFETY_SPAN_NAME} {shield_id}" - if response.violation: - if response.violation.metadata: - metadata_json = json.dumps(response.violation.metadata) - span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json) - if response.violation.user_message: - span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message) - span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value) + +# TODO: Consider using Wrapt to automatically instrument code +# This is the industry standard way to package automatically instrumentation in python. +def safety_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None: + span = trace.get_current_span() + span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id) + messages_json = json.dumps([msg.model_dump() for msg in messages]) + span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json) + + if response.violation: + if response.violation.metadata: + metadata_json = json.dumps(response.violation.metadata) + span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json) + if response.violation.user_message: + span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message) + span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value)