mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(telemetry): fixes based on testing
This commit is contained in:
parent
9e2b92b8d2
commit
420d267364
5 changed files with 46 additions and 29 deletions
|
|
@ -371,6 +371,12 @@ class SafetyConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class TelemetryConfig(BaseModel):
    """Configuration for telemetry collection."""

    # Telemetry is opt-in: the flag's default is False.
    enabled: bool = Field(
        default=False,
        description="Whether telemetry collection is enabled",
    )
|
||||
|
||||
|
||||
class QuotaPeriod(StrEnum):
|
||||
DAY = "day"
|
||||
|
||||
|
|
@ -536,6 +542,11 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
description="Configuration for default moderations model",
|
||||
)
|
||||
|
||||
telemetry: TelemetryConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for telemetry collection",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
|
|
|
|||
|
|
@ -85,7 +85,6 @@ async def get_auto_router_impl(
|
|||
)
|
||||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
|
||||
elif api == Api.vector_io:
|
||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||
elif api == Api.safety:
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import SafetyConfig
|
||||
from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME
|
||||
from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes
|
||||
from llama_stack.core.telemetry.helpers import safety_request_span_attributes, safety_span_name
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
||||
|
||||
|
|
@ -50,22 +49,22 @@ class SafetyRouter(Safety):
|
|||
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
||||
return await self.routing_table.unregister_shield(identifier)
|
||||
|
||||
async def run_shield(
    self,
    shield_id: str,
    messages: list[OpenAIMessageParam],
    params: dict[str, Any] | None = None,
) -> RunShieldResponse:
    """Run a shield over ``messages`` inside a per-shield tracing span.

    The provider for ``shield_id`` is looked up through the routing table and
    its ``run_shield`` is awaited with the same arguments. The request and
    response are then passed to ``safety_request_span_attributes``.

    Args:
        shield_id: Identifier used both to resolve the provider and to name
            the span (via ``safety_span_name``).
        messages: Messages forwarded unchanged to the provider.
        params: Optional provider parameters; ``None`` is forwarded as-is
            when the caller supplies nothing.

    Returns:
        The provider's ``RunShieldResponse``, unmodified.
    """
    # The attribute helper is called while the span is still active so that
    # whatever it records is associated with this shield run.
    with tracer.start_as_current_span(name=safety_span_name(shield_id)):
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        provider = await self.routing_table.get_provider_impl(shield_id)
        response = await provider.run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )

        safety_request_span_attributes(shield_id, messages, response)
    return response
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
|
|
|
|||
|
|
@ -4,14 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Common namespace for every Llama Stack telemetry name defined below.
LLAMA_STACK_PREFIX = "llama_stack"
# Lower-case spelling kept as an alias for existing importers.
llama_stack_prefix = LLAMA_STACK_PREFIX

# Safety Attributes
GUARDRAIL_SPAN_NAME = f"{LLAMA_STACK_PREFIX}.guardrail"
SAFETY_SPAN_NAME = "safety.run_shield"

# Request-side attribute keys.
SAFETY_REQUEST_PREFIX = f"{LLAMA_STACK_PREFIX}.safety.request"
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"

# Response-side attribute keys.
SAFETY_RESPONSE_PREFIX = f"{LLAMA_STACK_PREFIX}.safety.response"
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"
|
||||
|
|
|
|||
|
|
@ -17,20 +17,26 @@ from .constants import (
|
|||
SAFETY_RESPONSE_METADATA_ATTRIBUTE,
|
||||
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
|
||||
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
|
||||
SAFETY_SPAN_NAME,
|
||||
)
|
||||
|
||||
|
||||
def guardrail_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
|
||||
span = trace.get_current_span()
|
||||
if span is not None:
|
||||
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
|
||||
messages_json = json.dumps([msg.model_dump() for msg in messages])
|
||||
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
|
||||
def safety_span_name(shield_id: str) -> str:
    """Return the span name for a shield run: the base span name, a space, then ``shield_id``."""
    base_name = SAFETY_SPAN_NAME
    return f"{base_name} {shield_id}"
|
||||
|
||||
if response.violation:
|
||||
if response.violation.metadata:
|
||||
metadata_json = json.dumps(response.violation.metadata)
|
||||
span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json)
|
||||
if response.violation.user_message:
|
||||
span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message)
|
||||
span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value)
|
||||
|
||||
# TODO: Consider using Wrapt to automatically instrument code
|
||||
# This is the industry-standard way to package automatic instrumentation in Python.
|
||||
def safety_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
    """Attach a shield run's request and response details to the current span.

    The shield id and the JSON-encoded ``model_dump()`` of every message are
    always recorded. When the response carries a violation, its metadata
    (JSON-encoded) and user message are recorded if truthy, and its violation
    level value is always recorded.

    Args:
        shield_id: Identifier of the shield that was run.
        messages: Messages that were sent to the shield.
        response: Result returned by the shield.
    """
    current_span = trace.get_current_span()

    # Request side.
    current_span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
    dumped_messages = [message.model_dump() for message in messages]
    current_span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, json.dumps(dumped_messages))

    # Response side: nothing further to record when there is no violation.
    violation = response.violation
    if not violation:
        return

    if violation.metadata:
        current_span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, json.dumps(violation.metadata))
    if violation.user_message:
        current_span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, violation.user_message)
    current_span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, violation.violation_level.value)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue