mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(telemetry): fixes based on testing
This commit is contained in:
parent
9e2b92b8d2
commit
420d267364
5 changed files with 46 additions and 29 deletions
|
|
@ -371,6 +371,12 @@ class SafetyConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TelemetryConfig(BaseModel):
|
||||||
|
"""Configuration for telemetry collection."""
|
||||||
|
|
||||||
|
enabled: bool = Field(default=False, description="Whether telemetry collection is enabled")
|
||||||
|
|
||||||
|
|
||||||
class QuotaPeriod(StrEnum):
|
class QuotaPeriod(StrEnum):
|
||||||
DAY = "day"
|
DAY = "day"
|
||||||
|
|
||||||
|
|
@ -536,6 +542,11 @@ can be instantiated multiple times (with different configs) if necessary.
|
||||||
description="Configuration for default moderations model",
|
description="Configuration for default moderations model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
telemetry: TelemetryConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Configuration for telemetry collection",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("external_providers_dir")
|
@field_validator("external_providers_dir")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_external_providers_dir(cls, v):
|
def validate_external_providers_dir(cls, v):
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,6 @@ async def get_auto_router_impl(
|
||||||
)
|
)
|
||||||
await inference_store.initialize()
|
await inference_store.initialize()
|
||||||
api_to_dep_impl["store"] = inference_store
|
api_to_dep_impl["store"] = inference_store
|
||||||
|
|
||||||
elif api == Api.vector_io:
|
elif api == Api.vector_io:
|
||||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,7 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.core.datatypes import SafetyConfig
|
from llama_stack.core.datatypes import SafetyConfig
|
||||||
from llama_stack.core.telemetry.constants import GUARDRAIL_SPAN_NAME
|
from llama_stack.core.telemetry.helpers import safety_request_span_attributes, safety_span_name
|
||||||
from llama_stack.core.telemetry.helpers import guardrail_request_span_attributes
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
||||||
|
|
||||||
|
|
@ -50,13 +49,13 @@ class SafetyRouter(Safety):
|
||||||
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
||||||
return await self.routing_table.unregister_shield(identifier)
|
return await self.routing_table.unregister_shield(identifier)
|
||||||
|
|
||||||
@tracer.start_as_current_span(name=GUARDRAIL_SPAN_NAME)
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
messages: list[OpenAIMessageParam],
|
messages: list[OpenAIMessageParam],
|
||||||
params: dict[str, Any] = None,
|
params: dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
with tracer.start_as_current_span(name=safety_span_name(shield_id)):
|
||||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||||
response = await provider.run_shield(
|
response = await provider.run_shield(
|
||||||
|
|
@ -65,7 +64,7 @@ class SafetyRouter(Safety):
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
guardrail_request_span_attributes(shield_id, messages, response)
|
safety_request_span_attributes(shield_id, messages, response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,16 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
# Safety Attributes
|
llama_stack_prefix = "llama_stack"
|
||||||
GUARDRAIL_SPAN_NAME = "llama_stack.guardrail"
|
|
||||||
|
|
||||||
SAFETY_REQUEST_PREFIX = "llama_stack.safety.request"
|
# Safety Attributes
|
||||||
|
SAFETY_SPAN_NAME = "safety.run_shield"
|
||||||
|
|
||||||
|
SAFETY_REQUEST_PREFIX = f"{llama_stack_prefix}.safety.request"
|
||||||
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
|
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
|
||||||
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
|
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
|
||||||
|
|
||||||
SAFETY_RESPONSE_PREFIX = "llama_stack.safety.response"
|
SAFETY_RESPONSE_PREFIX = f"{llama_stack_prefix}.safety.response"
|
||||||
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
|
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
|
||||||
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
|
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
|
||||||
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"
|
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"
|
||||||
|
|
|
||||||
|
|
@ -17,12 +17,18 @@ from .constants import (
|
||||||
SAFETY_RESPONSE_METADATA_ATTRIBUTE,
|
SAFETY_RESPONSE_METADATA_ATTRIBUTE,
|
||||||
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
|
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
|
||||||
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
|
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
|
||||||
|
SAFETY_SPAN_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def guardrail_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
|
def safety_span_name(shield_id: str) -> str:
|
||||||
|
return f"{SAFETY_SPAN_NAME} {shield_id}"
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Consider using Wrapt to automatically instrument code
|
||||||
|
# This is the industry-standard way to package automatic instrumentation in Python.
|
||||||
|
def safety_request_span_attributes(shield_id: str, messages: list[Message], response: RunShieldResponse) -> None:
|
||||||
span = trace.get_current_span()
|
span = trace.get_current_span()
|
||||||
if span is not None:
|
|
||||||
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
|
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
|
||||||
messages_json = json.dumps([msg.model_dump() for msg in messages])
|
messages_json = json.dumps([msg.model_dump() for msg in messages])
|
||||||
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
|
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue