diff --git a/src/llama_stack_api/agents.py b/src/llama_stack_api/agents.py
index 8d3b489e1..a73bf918c 100644
--- a/src/llama_stack_api/agents.py
+++ b/src/llama_stack_api/agents.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from collections.abc import AsyncIterator
-from typing import Annotated, Protocol, runtime_checkable
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 
@@ -30,12 +30,59 @@ from .openai_responses import (
 
 
 class ResponseGuardrailSpec(BaseModel):
     """Specification for a guardrail to apply during response generation.
 
-    :param type: The type/identifier of the guardrail.
+    Production-oriented configuration covering safety, moderation, and policy controls.
+
+    :param type: Identifier of the guardrail implementation (e.g. 'llama-guard', 'content-filter').
+    :param description: Human-readable explanation of the guardrail's purpose.
+    :param enabled: Whether enforcement is active. Defaults to True.
+    :param severity: Severity classification for violations ('info', 'warn', or 'block').
+    :param action: Action taken on a violation ('flag', 'block', 'redact', or 'annotate'); if omitted, the provider default applies.
+    :param policy_id: Optional external policy identifier mapping violations to organizational policy.
+    :param version: Optional version of this guardrail configuration, for audit and rollback.
+    :param categories: Safety/moderation categories this guardrail targets (e.g. ['violence', 'self-harm']).
+    :param thresholds: Per-category numeric thresholds (e.g. {'violence': 0.8}); semantics are provider-defined.
+    :param max_violations: If set, the number of violations allowed before early termination.
+    :param config: Provider- or model-specific free-form settings; nesting is allowed.
+    :param tags: Arbitrary labels to assist analytics, telemetry, and routing.
+    :param metadata: Arbitrary supplemental structured metadata for downstream logging.
     """
 
     type: str
-    # TODO: more fields to be added for guardrail configuration
+    description: str | None = None
+    enabled: bool = True
+    severity: Literal["info", "warn", "block"] | None = None
+    action: Literal["flag", "block", "redact", "annotate"] | None = None
+    policy_id: str | None = None
+    version: str | None = None
+    categories: list[str] | None = None
+    thresholds: dict[str, float] | None = None
+    max_violations: int | None = None
+    config: dict[str, Any] | None = None
+    tags: list[str] | None = None
+    metadata: dict[str, Any] | None = None
+
+    model_config = {
+        "extra": "forbid",
+        "title": "ResponseGuardrailSpec",
+    }
+
+    @field_validator("type")
+    @classmethod
+    def _validate_type(cls, value: str) -> str:
+        """Reject empty or whitespace-only guardrail identifiers."""
+        if not value.strip():
+            raise ValueError("type cannot be empty")
+        return value
+
+    def normalized(self) -> "ResponseGuardrailSpec":
+        """Return a normalized copy with categories stripped and lower-cased."""
+        if not self.categories:
+            return self
+        return self.model_copy(
+            update={"categories": [c.strip().lower() for c in self.categories]}
+        )
+
 
 ResponseGuardrail = str | ResponseGuardrailSpec
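
For reviewers, a minimal usage sketch of the model added above (not part of the diff). It assumes pydantic v2 and that the module is importable as `llama_stack_api.agents`; the `spec`, `clean`, and `guardrails` names are illustrative only.

```python
# Usage sketch, assuming the diff above has been applied and pydantic v2
# is installed. Field names match the new ResponseGuardrailSpec model.
from pydantic import ValidationError

from llama_stack_api.agents import ResponseGuardrailSpec

# A fully specified guardrail; unknown keys are rejected via extra="forbid".
spec = ResponseGuardrailSpec(
    type="llama-guard",
    severity="block",
    action="block",
    categories=["  Violence ", "SELF-HARM"],
    thresholds={"violence": 0.8},
)

# normalized() returns a copy with categories stripped and lower-cased;
# the original instance is left untouched.
clean = spec.normalized()
assert clean.categories == ["violence", "self-harm"]
assert spec.categories == ["  Violence ", "SELF-HARM"]

# extra="forbid" makes typos fail fast instead of being silently dropped.
try:
    ResponseGuardrailSpec(type="llama-guard", serverity="block")  # note the typo
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "extra_forbidden"

# The ResponseGuardrail alias also accepts a bare identifier string.
guardrails: list[str | ResponseGuardrailSpec] = ["llama-guard", clean]
```

Returning a fresh instance from `normalized()` (via `model_copy(update=...)`) rather than mutating in place keeps the method's behavior consistent with its docstring and avoids surprising callers who hold a reference to the original spec.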