use guardrails and run_moderation api

Swapna Lekkala 2025-10-13 12:12:53 -07:00
parent 171fb7101d
commit c10db23d7a
16 changed files with 184 additions and 195 deletions

View file

@@ -43,17 +43,17 @@ from .openai_responses import (
@json_schema_type
class ResponseShieldSpec(BaseModel):
"""Specification for a shield to apply during response generation.
class ResponseGuardrailSpec(BaseModel):
"""Specification for a guardrail to apply during response generation.
:param type: The type/identifier of the shield.
:param type: The type/identifier of the guardrail.
"""
type: str
# TODO: more fields to be added for shield configuration
# TODO: more fields to be added for guardrail configuration
ResponseShield = str | ResponseShieldSpec
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
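The new union type accepts either a bare guardrail identifier or a structured spec; a minimal sketch (the identifier "llama-guard" is an illustrative placeholder, not something defined in this commit):

# Either form is accepted wherever a ResponseGuardrail is expected.
guardrails: list[ResponseGuardrail] = [
    "llama-guard",                              # plain guardrail identifier
    ResponseGuardrailSpec(type="llama-guard"),  # structured spec form
]
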
@@ -218,8 +218,8 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
input_guardrails: list[str] | None = Field(default_factory=lambda: [])
output_guardrails: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
@@ -820,10 +820,10 @@ class Agents(Protocol):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
shields: Annotated[
list[ResponseShield] | None,
guardrails: Annotated[
list[ResponseGuardrail] | None,
ExtraBodyField(
"List of shields to apply during response generation. Shields provide safety and content moderation."
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
@@ -834,7 +834,7 @@
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:returns: An OpenAIResponseObject.
"""
...
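Because guardrails is declared as an ExtraBodyField, a client on the OpenAI-compatible surface would pass it through the request's extra body. A rough sketch, assuming the openai Python SDK pointed at a running Llama Stack server (the base URL, model id, and guardrail id are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # placeholder endpoint
response = client.responses.create(
    model="llama3.2:3b",                         # placeholder model id
    input="Tell me how to hotwire a car",
    extra_body={"guardrails": ["llama-guard"]},  # string IDs or {"type": ...} spec dicts
)
print(response.output_text)
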

View file

@@ -338,7 +338,7 @@ class MetaReferenceAgentsImpl(Agents):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
guardrails: list | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input,
@@ -353,7 +353,7 @@ class MetaReferenceAgentsImpl(Agents):
tools,
include,
max_infer_iters,
shields,
guardrails,
)
async def list_openai_responses(

View file

@@ -49,7 +49,7 @@ from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_shield_ids,
extract_guardrail_ids,
)
logger = get_logger(name=__name__, category="openai_responses")
@@ -236,12 +236,12 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
guardrails: list | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
shield_ids = extract_shield_ids(shields) if shields else []
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
if conversation is not None:
if previous_response_id is not None:
@@ -263,7 +263,7 @@
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
shield_ids=shield_ids,
guardrail_ids=guardrail_ids,
)
if stream:
@@ -309,7 +309,7 @@
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
shield_ids: list[str] | None = None,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
@@ -345,7 +345,7 @@
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
shield_ids=shield_ids,
guardrail_ids=guardrail_ids,
)
# Stream the response

View file

@@ -56,9 +56,7 @@ from llama_stack.apis.agents.openai_responses import (
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
CompletionMessage,
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
@@ -66,9 +64,10 @@ from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
StopReason,
OpenAIUserMessageParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.telemetry import tracing
from ..safety import SafetyException
@@ -76,7 +75,7 @@ from .types import ChatCompletionContext, ChatCompletionResult
from .utils import (
convert_chat_choice_to_response_message,
is_function_tool_call,
run_multiple_shields,
run_multiple_guardrails,
)
logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -114,7 +113,7 @@ class StreamingResponseOrchestrator:
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
safety_api,
shield_ids: list[str] | None = None,
guardrail_ids: list[str] | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -124,7 +123,7 @@
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.safety_api = safety_api
self.shield_ids = shield_ids or []
self.guardrail_ids = guardrail_ids or []
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
@@ -137,28 +136,33 @@
# Track if we've sent a refusal response
self.violation_detected = False
async def _check_input_safety(self, messages: list[Message]) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
async def _check_input_safety(
self, messages: list[OpenAIUserMessageParam]
) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against guardrails. Returns refusal content if violation found."""
combined_text = interleaved_content_as_str([msg.content for msg in messages])
if not combined_text:
return None
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
except SafetyException as e:
logger.info(f"Input shield violation: {e.violation.user_message}")
logger.info(f"Input guardrail violation: {e.violation.user_message}")
return OpenAIResponseContentPartRefusal(
refusal=e.violation.user_message or "Content blocked by safety shields"
refusal=e.violation.user_message or "Content blocked by safety guardrails"
)
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
"""Check accumulated streaming text content against shields. Returns violation message if blocked."""
if not self.shield_ids or not accumulated_text:
"""Check accumulated streaming text content against guardrails. Returns violation message if blocked."""
if not self.guardrail_ids or not accumulated_text:
return None
messages = [CompletionMessage(content=accumulated_text, stop_reason=StopReason.end_of_turn)]
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
except SafetyException as e:
logger.info(f"Output shield violation: {e.violation.user_message}")
return e.violation.user_message or "Generated content blocked by safety shields"
logger.info(f"Output guardrail violation: {e.violation.user_message}")
return e.violation.user_message or "Generated content blocked by safety guardrails"
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
@@ -219,7 +223,7 @@ class StreamingResponseOrchestrator:
)
# Input safety validation - check messages before processing
if self.shield_ids:
if self.guardrail_ids:
input_refusal = await self._check_input_safety(self.ctx.messages)
if input_refusal:
# Return refusal response immediately
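Taken together, the two hooks follow one pattern: flatten the relevant text, run the configured guardrails, and convert a SafetyException into a refusal content part. A condensed, self-contained sketch of that flow (the function name check_text is illustrative and not part of this commit):

from llama_stack.apis.agents.openai_responses import OpenAIResponseContentPartRefusal

async def check_text(safety_api, text: str, guardrail_ids: list[str]) -> OpenAIResponseContentPartRefusal | None:
    # Mirrors _check_input_safety / _check_output_stream_chunk_safety above.
    if not guardrail_ids or not text:
        return None
    try:
        await run_multiple_guardrails(safety_api, text, guardrail_ids)  # from .utils, as imported above
    except SafetyException as e:  # from ..safety, as imported above
        return OpenAIResponseContentPartRefusal(
            refusal=e.violation.user_message or "Content blocked by safety guardrails"
        )
    return None
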

View file

@@ -8,7 +8,7 @@ import asyncio
import re
import uuid
from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
@@ -28,7 +28,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
)
from llama_stack.apis.inference import (
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
@@ -314,38 +313,58 @@ def is_function_tool_call(
return False
async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
"""Run multiple shields against messages and raise SafetyException for violations."""
if not shield_ids or not messages:
async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> None:
"""Run multiple guardrails against messages and raise SafetyException for violations."""
if not guardrail_ids or not messages:
return
shield_tasks = [
safety_api.run_shield(shield_id=shield_id, messages=messages, params={}) for shield_id in shield_ids
]
responses = await asyncio.gather(*shield_tasks)
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()
for guardrail_id in guardrail_ids:
# Find the shield with this identifier
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
if matching_shields:
model_id = matching_shields[0].provider_resource_id
model_ids.append(model_id)
else:
# If no shield found, raise an error
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
responses = await asyncio.gather(*guardrail_tasks)
for response in responses:
if response.violation and response.violation.violation_level.name == "ERROR":
if response.flagged:
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from ..safety import SafetyException
raise SafetyException(response.violation)
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="Content flagged by moderation",
metadata={"categories": response.categories},
)
raise SafetyException(violation)
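# Illustrative usage sketch (not part of this commit): assumes `safety_api` is a Safety
# implementation and "llama-guard" is the identifier of a registered shield.
async def _example_guardrail_check(safety_api: Safety) -> str | None:
    from ..safety import SafetyException  # same package-relative import used above
    try:
        await run_multiple_guardrails(safety_api, "user text to moderate", ["llama-guard"])
    except SafetyException as e:
        # run_moderation flagged the text; the violation carries the flagged categories.
        return e.violation.user_message
    return None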
def extract_shield_ids(shields: list | None) -> list[str]:
"""Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects."""
if not shields:
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
if not guardrails:
return []
shield_ids = []
for shield in shields:
if isinstance(shield, str):
shield_ids.append(shield)
elif isinstance(shield, ResponseShieldSpec):
shield_ids.append(shield.type)
guardrail_ids = []
for guardrail in guardrails:
if isinstance(guardrail, str):
guardrail_ids.append(guardrail)
elif isinstance(guardrail, ResponseGuardrailSpec):
guardrail_ids.append(guardrail.type)
else:
raise ValueError(f"Unknown shield format: {shield}, expected str or ResponseShieldSpec")
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
return shield_ids
return guardrail_ids
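# Behaviour sketch for extract_guardrail_ids (the ids below are illustrative placeholders):
assert extract_guardrail_ids(None) == []
assert extract_guardrail_ids(["llama-guard"]) == ["llama-guard"]
assert extract_guardrail_ids([ResponseGuardrailSpec(type="content-filter")]) == ["content-filter"]
# Any other element type raises ValueError("Unknown guardrail format: ...").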
def extract_text_content(content: str | list | None) -> str | None: