Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 12:06:04 +00:00)

Commit c10db23d7a (parent 171fb7101d): use guardrails and run_moderation api

16 changed files with 184 additions and 195 deletions
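For orientation: the integration tests further down in this commit pass guardrails to the Responses API through the client's `extra_body`. A minimal sketch of that call, assuming an OpenAI-compatible client pointed at a running Llama Stack server and a guardrail registered as "llama-guard" (the base URL and model ID below are placeholders):

```python
# Sketch only. Assumptions: a Llama Stack server reachable at this base URL,
# and a guardrail/shield registered under the identifier "llama-guard".
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model ID
    input="How can I hurt someone?",
    stream=False,
    # "guardrails" is a Llama Stack extension, so it travels via extra_body.
    extra_body={"guardrails": ["llama-guard"]},
)
```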

docs/static/deprecated-llama-stack-spec.html (vendored), 18 lines changed

@@ -2135,7 +2135,7 @@
         "deprecated": true,
         "x-llama-stack-extra-body-params": [
           {
-            "name": "shields",
+            "name": "guardrails",
             "schema": {
               "type": "array",
               "items": {

@@ -2144,12 +2144,12 @@
                     "type": "string"
                   },
                   {
-                    "$ref": "#/components/schemas/ResponseShieldSpec"
+                    "$ref": "#/components/schemas/ResponseGuardrailSpec"
                   }
                 ]
               }
             },
-            "description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
+            "description": "List of guardrails to apply during response generation. Guardrails provide safety and content moderation.",
             "required": false
           }
         ]

@@ -3615,13 +3615,13 @@
         "sampling_params": {
           "$ref": "#/components/schemas/SamplingParams"
         },
-        "input_shields": {
+        "input_guardrails": {
           "type": "array",
           "items": {
             "type": "string"
           }
         },
-        "output_shields": {
+        "output_guardrails": {
           "type": "array",
           "items": {
             "type": "string"

@@ -9606,20 +9606,20 @@
         "title": "OpenAIResponseUsage",
         "description": "Usage information for OpenAI response."
       },
-      "ResponseShieldSpec": {
+      "ResponseGuardrailSpec": {
         "type": "object",
         "properties": {
           "type": {
             "type": "string",
-            "description": "The type/identifier of the shield."
+            "description": "The type/identifier of the guardrail."
           }
         },
         "additionalProperties": false,
         "required": [
           "type"
         ],
-        "title": "ResponseShieldSpec",
-        "description": "Specification for a shield to apply during response generation."
+        "title": "ResponseGuardrailSpec",
+        "description": "Specification for a guardrail to apply during response generation."
       },
       "OpenAIResponseInputTool": {
         "oneOf": [
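As the schema above shows, each entry in `guardrails` may be either a bare identifier string or a `ResponseGuardrailSpec` object with a `type` field. A sketch of the corresponding extra-body payload (identifiers are placeholders):

```python
# Shape implied by the schema above; "llama-guard" and "content-filter"
# are placeholder identifiers, not guaranteed to exist in a deployment.
extra_body = {
    "guardrails": [
        "llama-guard",               # plain guardrail ID (string form)
        {"type": "content-filter"},  # ResponseGuardrailSpec form
    ]
}
```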

docs/static/deprecated-llama-stack-spec.yaml (vendored), 20 lines changed

@@ -1569,16 +1569,16 @@ paths:
       required: true
       deprecated: true
       x-llama-stack-extra-body-params:
-       - name: shields
+       - name: guardrails
         schema:
           type: array
           items:
             oneOf:
               - type: string
-              - $ref: '#/components/schemas/ResponseShieldSpec'
+              - $ref: '#/components/schemas/ResponseGuardrailSpec'
         description: >-
-          List of shields to apply during response generation. Shields provide safety
-          and content moderation.
+          List of guardrails to apply during response generation. Guardrails provide
+          safety and content moderation.
         required: false
   /v1/openai/v1/responses/{response_id}:
     get:

@@ -2667,11 +2667,11 @@ components:
     properties:
       sampling_params:
         $ref: '#/components/schemas/SamplingParams'
-      input_shields:
+      input_guardrails:
        type: array
        items:
          type: string
-      output_shields:
+      output_guardrails:
        type: array
        items:
          type: string

@@ -7177,18 +7177,18 @@ components:
         - total_tokens
       title: OpenAIResponseUsage
       description: Usage information for OpenAI response.
-    ResponseShieldSpec:
+    ResponseGuardrailSpec:
       type: object
       properties:
         type:
           type: string
-          description: The type/identifier of the shield.
+          description: The type/identifier of the guardrail.
       additionalProperties: false
       required:
         - type
-      title: ResponseShieldSpec
+      title: ResponseGuardrailSpec
       description: >-
-        Specification for a shield to apply during response generation.
+        Specification for a guardrail to apply during response generation.
     OpenAIResponseInputTool:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'

@@ -2090,13 +2090,13 @@
         "sampling_params": {
           "$ref": "#/components/schemas/SamplingParams"
         },
-        "input_shields": {
+        "input_guardrails": {
           "type": "array",
           "items": {
             "type": "string"
           }
         },
-        "output_shields": {
+        "output_guardrails": {
           "type": "array",
           "items": {
             "type": "string"

@@ -1500,11 +1500,11 @@ components:
     properties:
       sampling_params:
         $ref: '#/components/schemas/SamplingParams'
-      input_shields:
+      input_guardrails:
        type: array
        items:
          type: string
-      output_shields:
+      output_guardrails:
        type: array
        items:
          type: string

docs/static/llama-stack-spec.html (vendored), 14 lines changed

@@ -1833,7 +1833,7 @@
         "deprecated": false,
         "x-llama-stack-extra-body-params": [
           {
-            "name": "shields",
+            "name": "guardrails",
             "schema": {
               "type": "array",
               "items": {

@@ -1842,12 +1842,12 @@
                     "type": "string"
                   },
                   {
-                    "$ref": "#/components/schemas/ResponseShieldSpec"
+                    "$ref": "#/components/schemas/ResponseGuardrailSpec"
                   }
                 ]
               }
             },
-            "description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
+            "description": "List of guardrails to apply during response generation. Guardrails provide safety and content moderation.",
             "required": false
           }
         ]

@@ -7854,20 +7854,20 @@
         "title": "OpenAIResponseUsage",
         "description": "Usage information for OpenAI response."
       },
-      "ResponseShieldSpec": {
+      "ResponseGuardrailSpec": {
         "type": "object",
         "properties": {
           "type": {
             "type": "string",
-            "description": "The type/identifier of the shield."
+            "description": "The type/identifier of the guardrail."
          }
        },
        "additionalProperties": false,
        "required": [
          "type"
        ],
-        "title": "ResponseShieldSpec",
-        "description": "Specification for a shield to apply during response generation."
+        "title": "ResponseGuardrailSpec",
+        "description": "Specification for a guardrail to apply during response generation."
      },
      "OpenAIResponseInputTool": {
        "oneOf": [

docs/static/llama-stack-spec.yaml (vendored), 16 lines changed

@@ -1448,16 +1448,16 @@ paths:
       required: true
       deprecated: false
       x-llama-stack-extra-body-params:
-       - name: shields
+       - name: guardrails
         schema:
           type: array
           items:
             oneOf:
               - type: string
-              - $ref: '#/components/schemas/ResponseShieldSpec'
+              - $ref: '#/components/schemas/ResponseGuardrailSpec'
         description: >-
-          List of shields to apply during response generation. Shields provide safety
-          and content moderation.
+          List of guardrails to apply during response generation. Guardrails provide
+          safety and content moderation.
         required: false
   /v1/responses/{response_id}:
     get:

@@ -5973,18 +5973,18 @@ components:
         - total_tokens
       title: OpenAIResponseUsage
       description: Usage information for OpenAI response.
-    ResponseShieldSpec:
+    ResponseGuardrailSpec:
       type: object
       properties:
         type:
           type: string
-          description: The type/identifier of the shield.
+          description: The type/identifier of the guardrail.
       additionalProperties: false
       required:
         - type
-      title: ResponseShieldSpec
+      title: ResponseGuardrailSpec
       description: >-
-        Specification for a shield to apply during response generation.
+        Specification for a guardrail to apply during response generation.
     OpenAIResponseInputTool:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'

docs/static/stainless-llama-stack-spec.html (vendored), 18 lines changed

@@ -1833,7 +1833,7 @@
         "deprecated": false,
         "x-llama-stack-extra-body-params": [
           {
-            "name": "shields",
+            "name": "guardrails",
             "schema": {
               "type": "array",
               "items": {

@@ -1842,12 +1842,12 @@
                     "type": "string"
                   },
                   {
-                    "$ref": "#/components/schemas/ResponseShieldSpec"
+                    "$ref": "#/components/schemas/ResponseGuardrailSpec"
                   }
                 ]
               }
             },
-            "description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
+            "description": "List of guardrails to apply during response generation. Guardrails provide safety and content moderation.",
             "required": false
           }
         ]

@@ -9526,20 +9526,20 @@
         "title": "OpenAIResponseUsage",
         "description": "Usage information for OpenAI response."
       },
-      "ResponseShieldSpec": {
+      "ResponseGuardrailSpec": {
         "type": "object",
         "properties": {
           "type": {
             "type": "string",
-            "description": "The type/identifier of the shield."
+            "description": "The type/identifier of the guardrail."
          }
        },
        "additionalProperties": false,
        "required": [
          "type"
        ],
-        "title": "ResponseShieldSpec",
-        "description": "Specification for a shield to apply during response generation."
+        "title": "ResponseGuardrailSpec",
+        "description": "Specification for a guardrail to apply during response generation."
      },
      "OpenAIResponseInputTool": {
        "oneOf": [

@@ -15192,13 +15192,13 @@
         "sampling_params": {
           "$ref": "#/components/schemas/SamplingParams"
         },
-        "input_shields": {
+        "input_guardrails": {
           "type": "array",
           "items": {
             "type": "string"
           }
         },
-        "output_shields": {
+        "output_guardrails": {
           "type": "array",
           "items": {
             "type": "string"

docs/static/stainless-llama-stack-spec.yaml (vendored), 20 lines changed

@@ -1451,16 +1451,16 @@ paths:
       required: true
       deprecated: false
       x-llama-stack-extra-body-params:
-       - name: shields
+       - name: guardrails
         schema:
           type: array
           items:
             oneOf:
               - type: string
-              - $ref: '#/components/schemas/ResponseShieldSpec'
+              - $ref: '#/components/schemas/ResponseGuardrailSpec'
         description: >-
-          List of shields to apply during response generation. Shields provide safety
-          and content moderation.
+          List of guardrails to apply during response generation. Guardrails provide
+          safety and content moderation.
         required: false
   /v1/responses/{response_id}:
     get:

@@ -7186,18 +7186,18 @@ components:
         - total_tokens
       title: OpenAIResponseUsage
       description: Usage information for OpenAI response.
-    ResponseShieldSpec:
+    ResponseGuardrailSpec:
       type: object
       properties:
         type:
           type: string
-          description: The type/identifier of the shield.
+          description: The type/identifier of the guardrail.
       additionalProperties: false
       required:
         - type
-      title: ResponseShieldSpec
+      title: ResponseGuardrailSpec
       description: >-
-        Specification for a shield to apply during response generation.
+        Specification for a guardrail to apply during response generation.
     OpenAIResponseInputTool:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'

@@ -11478,11 +11478,11 @@ components:
     properties:
       sampling_params:
         $ref: '#/components/schemas/SamplingParams'
-      input_shields:
+      input_guardrails:
        type: array
        items:
          type: string
-      output_shields:
+      output_guardrails:
        type: array
        items:
          type: string

@@ -43,17 +43,17 @@ from .openai_responses import (


 @json_schema_type
-class ResponseShieldSpec(BaseModel):
-    """Specification for a shield to apply during response generation.
+class ResponseGuardrailSpec(BaseModel):
+    """Specification for a guardrail to apply during response generation.

-    :param type: The type/identifier of the shield.
+    :param type: The type/identifier of the guardrail.
     """

     type: str
-    # TODO: more fields to be added for shield configuration
+    # TODO: more fields to be added for guardrail configuration


-ResponseShield = str | ResponseShieldSpec
+ResponseGuardrail = str | ResponseGuardrailSpec


 class Attachment(BaseModel):

@@ -218,8 +218,8 @@ register_schema(AgentToolGroup, name="AgentTool")
 class AgentConfigCommon(BaseModel):
     sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    input_shields: list[str] | None = Field(default_factory=lambda: [])
-    output_shields: list[str] | None = Field(default_factory=lambda: [])
+    input_guardrails: list[str] | None = Field(default_factory=lambda: [])
+    output_guardrails: list[str] | None = Field(default_factory=lambda: [])
     toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
     client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
     tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")

@@ -820,10 +820,10 @@ class Agents(Protocol):
         tools: list[OpenAIResponseInputTool] | None = None,
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
-        shields: Annotated[
-            list[ResponseShield] | None,
+        guardrails: Annotated[
+            list[ResponseGuardrail] | None,
             ExtraBodyField(
-                "List of shields to apply during response generation. Shields provide safety and content moderation."
+                "List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
             ),
         ] = None,
     ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:

@@ -834,7 +834,7 @@ class Agents(Protocol):
         :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
         :param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
         :param include: (Optional) Additional fields to include in the response.
-        :param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
+        :param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
         :returns: An OpenAIResponseObject.
         """
         ...
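Server-side, the `guardrails` parameter is typed as `list[ResponseGuardrail] | None`, where `ResponseGuardrail = str | ResponseGuardrailSpec`. A small sketch of both accepted forms, mirroring the mixed-format case exercised by the unit tests at the end of this diff (identifiers are placeholders):

```python
# Sketch only; identifiers are placeholders.
from llama_stack.apis.agents.agents import ResponseGuardrailSpec

guardrails = [
    "llama-guard",                                 # bare guardrail ID
    ResponseGuardrailSpec(type="content-filter"),  # structured spec form
]
# extract_guardrail_ids(guardrails) yields ["llama-guard", "content-filter"]
```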

@@ -338,7 +338,7 @@ class MetaReferenceAgentsImpl(Agents):
         tools: list[OpenAIResponseInputTool] | None = None,
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
-        shields: list | None = None,
+        guardrails: list | None = None,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
             input,

@@ -353,7 +353,7 @@ class MetaReferenceAgentsImpl(Agents):
             tools,
             include,
             max_infer_iters,
-            shields,
+            guardrails,
         )

     async def list_openai_responses(

@@ -49,7 +49,7 @@ from .types import ChatCompletionContext, ToolContext
 from .utils import (
     convert_response_input_to_chat_messages,
     convert_response_text_to_chat_response_format,
-    extract_shield_ids,
+    extract_guardrail_ids,
 )

 logger = get_logger(name=__name__, category="openai_responses")

@@ -236,12 +236,12 @@ class OpenAIResponsesImpl:
         tools: list[OpenAIResponseInputTool] | None = None,
         include: list[str] | None = None,
         max_infer_iters: int | None = 10,
-        shields: list | None = None,
+        guardrails: list | None = None,
     ):
         stream = bool(stream)
         text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

-        shield_ids = extract_shield_ids(shields) if shields else []
+        guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []

         if conversation is not None:
             if previous_response_id is not None:

@@ -263,7 +263,7 @@ class OpenAIResponsesImpl:
             text=text,
             tools=tools,
             max_infer_iters=max_infer_iters,
-            shield_ids=shield_ids,
+            guardrail_ids=guardrail_ids,
         )

         if stream:

@@ -309,7 +309,7 @@ class OpenAIResponsesImpl:
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
-        shield_ids: list[str] | None = None,
+        guardrail_ids: list[str] | None = None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Input preprocessing
         all_input, messages, tool_context = await self._process_input_with_previous_response(

@@ -345,7 +345,7 @@ class OpenAIResponsesImpl:
             max_infer_iters=max_infer_iters,
             tool_executor=self.tool_executor,
             safety_api=self.safety_api,
-            shield_ids=shield_ids,
+            guardrail_ids=guardrail_ids,
         )

         # Stream the response

@@ -56,9 +56,7 @@ from llama_stack.apis.agents.openai_responses import (
     WebSearchToolTypes,
 )
 from llama_stack.apis.inference import (
-    CompletionMessage,
     Inference,
-    Message,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,

@@ -66,9 +64,10 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
     OpenAIMessageParam,
-    StopReason,
+    OpenAIUserMessageParam,
 )
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.telemetry import tracing

 from ..safety import SafetyException

@@ -76,7 +75,7 @@ from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
     is_function_tool_call,
-    run_multiple_shields,
+    run_multiple_guardrails,
 )

 logger = get_logger(name=__name__, category="agents::meta_reference")

@@ -114,7 +113,7 @@ class StreamingResponseOrchestrator:
         max_infer_iters: int,
         tool_executor,  # Will be the tool execution logic from the main class
         safety_api,
-        shield_ids: list[str] | None = None,
+        guardrail_ids: list[str] | None = None,
     ):
         self.inference_api = inference_api
         self.ctx = ctx

@@ -124,7 +123,7 @@ class StreamingResponseOrchestrator:
         self.max_infer_iters = max_infer_iters
         self.tool_executor = tool_executor
         self.safety_api = safety_api
-        self.shield_ids = shield_ids or []
+        self.guardrail_ids = guardrail_ids or []
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}

@@ -137,28 +136,33 @@ class StreamingResponseOrchestrator:
         # Track if we've sent a refusal response
         self.violation_detected = False

-    async def _check_input_safety(self, messages: list[Message]) -> OpenAIResponseContentPartRefusal | None:
-        """Validate input messages against shields. Returns refusal content if violation found."""
+    async def _check_input_safety(
+        self, messages: list[OpenAIUserMessageParam]
+    ) -> OpenAIResponseContentPartRefusal | None:
+        """Validate input messages against guardrails. Returns refusal content if violation found."""
+        combined_text = interleaved_content_as_str([msg.content for msg in messages])
+
+        if not combined_text:
+            return None
+
         try:
-            await run_multiple_shields(self.safety_api, messages, self.shield_ids)
+            await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
         except SafetyException as e:
-            logger.info(f"Input shield violation: {e.violation.user_message}")
+            logger.info(f"Input guardrail violation: {e.violation.user_message}")
             return OpenAIResponseContentPartRefusal(
-                refusal=e.violation.user_message or "Content blocked by safety shields"
+                refusal=e.violation.user_message or "Content blocked by safety guardrails"
             )

     async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
-        """Check accumulated streaming text content against shields. Returns violation message if blocked."""
-        if not self.shield_ids or not accumulated_text:
+        """Check accumulated streaming text content against guardrails. Returns violation message if blocked."""
+        if not self.guardrail_ids or not accumulated_text:
             return None

-        messages = [CompletionMessage(content=accumulated_text, stop_reason=StopReason.end_of_turn)]
-
         try:
-            await run_multiple_shields(self.safety_api, messages, self.shield_ids)
+            await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
         except SafetyException as e:
-            logger.info(f"Output shield violation: {e.violation.user_message}")
-            return e.violation.user_message or "Generated content blocked by safety shields"
+            logger.info(f"Output guardrail violation: {e.violation.user_message}")
+            return e.violation.user_message or "Generated content blocked by safety guardrails"

     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""

@@ -219,7 +223,7 @@ class StreamingResponseOrchestrator:
         )

         # Input safety validation - check messages before processing
-        if self.shield_ids:
+        if self.guardrail_ids:
             input_refusal = await self._check_input_safety(self.ctx.messages)
             if input_refusal:
                 # Return refusal response immediately
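When a guardrail flags content, the orchestrator replaces the normal output with a refusal content part. A short sketch of how a caller might detect that outcome on a non-streaming response, following the assertions in the integration tests below:

```python
# Sketch only; mirrors the checks made in the integration tests below.
message = response.output[0]         # expected to be a "message" item
part = message.content[0]
if part.type == "refusal":
    print("Blocked:", part.refusal)  # the guardrail's user-facing message
```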

@@ -8,7 +8,7 @@ import asyncio
 import re
 import uuid

-from llama_stack.apis.agents.agents import ResponseShieldSpec
+from llama_stack.apis.agents.agents import ResponseGuardrailSpec
 from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInput,

@@ -28,7 +28,6 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseText,
 )
 from llama_stack.apis.inference import (
-    Message,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartParam,

@@ -314,38 +313,58 @@ def is_function_tool_call(
     return False


-async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
-    """Run multiple shields against messages and raise SafetyException for violations."""
-    if not shield_ids or not messages:
+async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> None:
+    """Run multiple guardrails against messages and raise SafetyException for violations."""
+    if not guardrail_ids or not messages:
         return
-    shield_tasks = [
-        safety_api.run_shield(shield_id=shield_id, messages=messages, params={}) for shield_id in shield_ids
-    ]
-
-    responses = await asyncio.gather(*shield_tasks)
+    # Look up shields to get their provider_resource_id (actual model ID)
+    model_ids = []
+    shields_list = await safety_api.routing_table.list_shields()
+
+    for guardrail_id in guardrail_ids:
+        # Find the shield with this identifier
+        matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
+        if matching_shields:
+            model_id = matching_shields[0].provider_resource_id
+            model_ids.append(model_id)
+        else:
+            # If no shield found, raise an error
+            raise ValueError(f"No shield found with identifier '{guardrail_id}'")
+
+    guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
+
+    responses = await asyncio.gather(*guardrail_tasks)

     for response in responses:
-        if response.violation and response.violation.violation_level.name == "ERROR":
+        if response.flagged:
+            from llama_stack.apis.safety import SafetyViolation, ViolationLevel
+
             from ..safety import SafetyException

-            raise SafetyException(response.violation)
+            violation = SafetyViolation(
+                violation_level=ViolationLevel.ERROR,
+                user_message="Content flagged by moderation",
+                metadata={"categories": response.categories},
+            )
+            raise SafetyException(violation)


-def extract_shield_ids(shields: list | None) -> list[str]:
-    """Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects."""
-    if not shields:
+def extract_guardrail_ids(guardrails: list | None) -> list[str]:
+    """Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
+    if not guardrails:
         return []

-    shield_ids = []
-    for shield in shields:
-        if isinstance(shield, str):
-            shield_ids.append(shield)
-        elif isinstance(shield, ResponseShieldSpec):
-            shield_ids.append(shield.type)
+    guardrail_ids = []
+    for guardrail in guardrails:
+        if isinstance(guardrail, str):
+            guardrail_ids.append(guardrail)
+        elif isinstance(guardrail, ResponseGuardrailSpec):
+            guardrail_ids.append(guardrail.type)
         else:
-            raise ValueError(f"Unknown shield format: {shield}, expected str or ResponseShieldSpec")
+            raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")

-    return shield_ids
+    return guardrail_ids


 def extract_text_content(content: str | list | None) -> str | None:
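In short, each guardrail ID is resolved to a registered shield's provider_resource_id and passed to the Safety API's run_moderation; any flagged result is converted into an ERROR-level SafetyViolation and raised as SafetyException. A hedged sketch of how a caller might exercise this helper, assuming an initialized Safety API handle and a shield registered as "llama-guard" (import paths are inferred from the diff above):

```python
# Sketch only. Assumptions: an initialized Safety API handle and a shield
# registered under the identifier "llama-guard"; import paths inferred
# from this diff, not guaranteed verbatim.
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    run_multiple_guardrails,
)
from llama_stack.providers.inline.agents.meta_reference.safety import SafetyException


async def moderate(safety_api, text: str) -> str | None:
    """Return the violation message if the text is flagged, else None."""
    try:
        await run_multiple_guardrails(safety_api, text, ["llama-guard"])
    except SafetyException as e:
        return e.violation.user_message
    return None
```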

@@ -300,13 +300,13 @@ def test_function_call_output_response_with_none_arguments(openai_client, client


 @pytest.mark.parametrize("stream", [True, False])
-def test_input_safety_shields(compat_client, text_model_id, stream):
-    """Test that input safety shields block inappropriate content."""
+def test_input_safety_guardrails(compat_client, text_model_id, stream):
+    """Test that input safety guardrails block inappropriate content."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")

-    # Test with shield that should block inappropriate content
-    # Note: This test assumes shields are configured in the test environment
+    # Test with guardrail that should block inappropriate content
+    # Note: This test assumes guardrails are configured in the test environment
     inappropriate_message = "How can I hurt someone?"

     response = compat_client.responses.create(

@@ -318,7 +318,7 @@ def test_input_safety_shields(compat_client, text_model_id, stream):
             }
         ],
         stream=stream,
-        extra_body={"shields": ["llama-guard"]},  # Shields passed via extra_body
+        extra_body={"guardrails": ["llama-guard"]},  # Guardrails passed via extra_body
     )

     if stream:

@@ -334,11 +334,11 @@ def test_input_safety_shields(compat_client, text_model_id, stream):
     else:
         response_to_check = response

-    # Check if the response contains a refusal (indicating shield blocked content)
+    # Check if the response contains a refusal (indicating guardrail blocked content)
     assert response_to_check.output[0].type == "message"
     message = response_to_check.output[0]

-    # Shield should ALWAYS block inappropriate input - expect refusal response
+    # Guardrail should ALWAYS block inappropriate input - expect refusal response
     assert len(message.content) > 0, "Content should not be empty"
     content_item = message.content[0]
     assert content_item.type == "refusal", f"Expected refusal content, but got: {content_item.type}"

@@ -346,16 +346,16 @@ def test_input_safety_shields(compat_client, text_model_id, stream):


 @pytest.mark.parametrize("stream", [True, False])
-def test_output_safety_shields(compat_client, text_model_id, stream):
-    """Test that output safety shields can block generated content."""
+def test_output_safety_guardrails(compat_client, text_model_id, stream):
+    """Test that output safety guardrails can block generated content."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")

-    # Use an ambiguous prompt that could potentially generate content requiring shield evaluation
+    # Use an ambiguous prompt that could potentially generate content requiring guardrail evaluation
     # This prompt: "Describe violent content."
     # Could result in:
-    # 1. Educational violent content → shield allows output → normal response
-    # 2. Violent techniques content → shield blocks output → refusal response
+    # 1. Educational violent content → guardrail allows output → normal response
+    # 2. Violent techniques content → guardrail blocks output → refusal response
     response = compat_client.responses.create(
         model=text_model_id,
         input=[

@@ -365,7 +365,7 @@ def test_output_safety_shields(compat_client, text_model_id, stream):
             }
         ],
         stream=stream,
-        extra_body={"shields": ["llama-guard"]},  # Output shield validation
+        extra_body={"guardrails": ["llama-guard"]},  # Output guardrail validation
     )

     if stream:

@@ -388,8 +388,8 @@ def test_output_safety_shields(compat_client, text_model_id, stream):
         assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"


-def test_shields_with_tools(compat_client, text_model_id):
-    """Test that shields work correctly when tools are present."""
+def test_guardrails_with_tools(compat_client, text_model_id):
+    """Test that guardrails work correctly when tools are present."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")


@@ -414,11 +414,11 @@ def test_shields_with_tools(compat_client, text_model_id):
                 },
             }
         ],
-        extra_body={"shields": ["llama-guard"]},
+        extra_body={"guardrails": ["llama-guard"]},
         stream=False,
     )

-    # Verify response completes successfully with tools and shields
+    # Verify response completes successfully with tools and guardrails
     assert response.id is not None
     assert len(response.output) > 0


@@ -1,34 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-Test for extra_body parameter support with shields example.
-
-This test demonstrates that parameters marked with ExtraBodyField annotation
-can be passed via extra_body in the client SDK and are received by the
-server-side implementation.
-"""
-
-import pytest
-from llama_stack_client import APIStatusError
-
-
-@pytest.mark.xfail(reason="Shields are not yet implemented inside responses")
-def test_shields_via_extra_body(compat_client, text_model_id):
-    """Test that shields parameter is received by the server and raises NotImplementedError."""
-
-    # Test with shields as list of strings (shield IDs)
-    with pytest.raises((APIStatusError, NotImplementedError)) as exc_info:
-        compat_client.responses.create(
-            model=text_model_id,
-            input="What is the capital of France?",
-            stream=False,
-            extra_body={"shields": ["test-shield-1", "test-shield-2"]},
-        )
-
-    # Verify the error message indicates shields are not implemented
-    error_message = str(exc_info.value)
-    assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower()

@@ -8,12 +8,12 @@ from unittest.mock import AsyncMock, MagicMock

 import pytest

-from llama_stack.apis.agents.agents import ResponseShieldSpec
+from llama_stack.apis.agents.agents import ResponseGuardrailSpec
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
     OpenAIResponsesImpl,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
-    extract_shield_ids,
+    extract_guardrail_ids,
     extract_text_content,
 )

@@ -38,53 +38,53 @@ def responses_impl(mock_apis):
     return OpenAIResponsesImpl(**mock_apis)


-def test_extract_shield_ids_from_strings(responses_impl):
-    """Test extraction from simple string shield IDs."""
-    shields = ["llama-guard", "content-filter", "nsfw-detector"]
-    result = extract_shield_ids(shields)
+def test_extract_guardrail_ids_from_strings(responses_impl):
+    """Test extraction from simple string guardrail IDs."""
+    guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
+    result = extract_guardrail_ids(guardrails)
     assert result == ["llama-guard", "content-filter", "nsfw-detector"]


-def test_extract_shield_ids_from_objects(responses_impl):
-    """Test extraction from ResponseShieldSpec objects."""
-    shields = [
-        ResponseShieldSpec(type="llama-guard"),
-        ResponseShieldSpec(type="content-filter"),
+def test_extract_guardrail_ids_from_objects(responses_impl):
+    """Test extraction from ResponseGuardrailSpec objects."""
+    guardrails = [
+        ResponseGuardrailSpec(type="llama-guard"),
+        ResponseGuardrailSpec(type="content-filter"),
     ]
-    result = extract_shield_ids(shields)
+    result = extract_guardrail_ids(guardrails)
     assert result == ["llama-guard", "content-filter"]


-def test_extract_shield_ids_mixed_formats(responses_impl):
+def test_extract_guardrail_ids_mixed_formats(responses_impl):
     """Test extraction from mixed string and object formats."""
-    shields = [
+    guardrails = [
         "llama-guard",
-        ResponseShieldSpec(type="content-filter"),
+        ResponseGuardrailSpec(type="content-filter"),
         "nsfw-detector",
     ]
-    result = extract_shield_ids(shields)
+    result = extract_guardrail_ids(guardrails)
     assert result == ["llama-guard", "content-filter", "nsfw-detector"]


-def test_extract_shield_ids_none_input(responses_impl):
+def test_extract_guardrail_ids_none_input(responses_impl):
     """Test extraction with None input."""
-    result = extract_shield_ids(None)
+    result = extract_guardrail_ids(None)
     assert result == []


-def test_extract_shield_ids_empty_list(responses_impl):
+def test_extract_guardrail_ids_empty_list(responses_impl):
     """Test extraction with empty list."""
-    result = extract_shield_ids([])
+    result = extract_guardrail_ids([])
     assert result == []


-def test_extract_shield_ids_unknown_format(responses_impl):
-    """Test extraction with unknown shield format raises ValueError."""
-    # Create an object that's neither string nor ResponseShieldSpec
-    unknown_object = {"invalid": "format"}  # Plain dict, not ResponseShieldSpec
-    shields = ["valid-shield", unknown_object, "another-shield"]
-    with pytest.raises(ValueError, match="Unknown shield format.*expected str or ResponseShieldSpec"):
-        extract_shield_ids(shields)
+def test_extract_guardrail_ids_unknown_format(responses_impl):
+    """Test extraction with unknown guardrail format raises ValueError."""
+    # Create an object that's neither string nor ResponseGuardrailSpec
+    unknown_object = {"invalid": "format"}  # Plain dict, not ResponseGuardrailSpec
+    guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
+    with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
+        extract_guardrail_ids(guardrails)


 def test_extract_text_content_string(responses_impl):