feat: Add responses and safety impl with extra body

Swapna Lekkala 2025-10-10 07:12:51 -07:00
parent 6954fe2274
commit 9152efa1a9
18 changed files with 833 additions and 164 deletions
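In practice, the new shields parameter rides along as extra_body on an otherwise standard Responses API call, and a triggered shield surfaces as a refusal content part in the assistant message. A minimal client-side sketch, mirroring the integration tests added below; the base URL, model id, and shield id are placeholders for whatever is registered on the running Llama Stack server:

from openai import OpenAI

# Placeholder endpoint and credentials for a locally running Llama Stack server.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model id
    input=[{"role": "user", "content": "How can I hurt someone?"}],
    stream=False,
    # "shields" is not part of the OpenAI SDK signature; extra_body forwards it as-is.
    extra_body={"shields": ["llama-guard"]},
)

# When a shield fires, the output message carries a refusal content part
# (type="refusal") instead of output_text.
content = response.output[0].content[0]
if content.type == "refusal":
    print("Blocked:", content.refusal)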

View file

@@ -8821,6 +8821,28 @@
}
}
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal",
"description": "Content part type identifier, always \"refusal\""
},
"refusal": {
"type": "string",
"description": "Refusal text supplied by the model"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal",
"description": "Refusal content within a streamed response part."
},
"OpenAIResponseError": { "OpenAIResponseError": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -9395,6 +9417,23 @@
}
},
"OpenAIResponseOutputMessageContent": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
},
{
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
}
},
"OpenAIResponseOutputMessageContentOutputText": {
"type": "object", "type": "object",
"properties": { "properties": {
"text": { "text": {
@@ -10291,28 +10330,6 @@
"title": "OpenAIResponseContentPartReasoningText",
"description": "Reasoning text emitted as part of a streamed response."
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal",
"description": "Content part type identifier, always \"refusal\""
},
"refusal": {
"type": "string",
"description": "Refusal text supplied by the model"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal",
"description": "Refusal content within a streamed response part."
},
"OpenAIResponseObjectStream": { "OpenAIResponseObjectStream": {
"oneOf": [ "oneOf": [
{ {

View file

@@ -6551,6 +6551,25 @@ components:
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
description: >-
Content part type identifier, always "refusal"
refusal:
type: string
description: Refusal text supplied by the model
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
description: >-
Refusal content within a streamed response part.
OpenAIResponseError:
type: object
properties:
@@ -6972,6 +6991,15 @@ components:
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
OpenAIResponseOutputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
discriminator:
propertyName: type
mapping:
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
"OpenAIResponseOutputMessageContentOutputText":
type: object
properties:
text:
@@ -7663,25 +7691,6 @@ components:
title: OpenAIResponseContentPartReasoningText
description: >-
Reasoning text emitted as part of a streamed response.
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
description: >-
Content part type identifier, always "refusal"
refusal:
type: string
description: Refusal text supplied by the model
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
description: >-
Refusal content within a streamed response part.
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'

View file

@@ -5858,6 +5858,28 @@
}
}
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal",
"description": "Content part type identifier, always \"refusal\""
},
"refusal": {
"type": "string",
"description": "Refusal text supplied by the model"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal",
"description": "Refusal content within a streamed response part."
},
"OpenAIResponseInputMessageContent": { "OpenAIResponseInputMessageContent": {
"oneOf": [ "oneOf": [
{ {
@@ -6001,6 +6023,23 @@
"description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios."
},
"OpenAIResponseOutputMessageContent": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
},
{
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
}
},
"OpenAIResponseOutputMessageContentOutputText": {
"type": "object", "type": "object",
"properties": { "properties": {
"text": { "text": {
@@ -8386,28 +8425,6 @@
"title": "OpenAIResponseContentPartReasoningText",
"description": "Reasoning text emitted as part of a streamed response."
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal",
"description": "Content part type identifier, always \"refusal\""
},
"refusal": {
"type": "string",
"description": "Refusal text supplied by the model"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal",
"description": "Refusal content within a streamed response part."
},
"OpenAIResponseObjectStream": { "OpenAIResponseObjectStream": {
"oneOf": [ "oneOf": [
{ {

View file

@@ -4416,6 +4416,25 @@ components:
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
description: >-
Content part type identifier, always "refusal"
refusal:
type: string
description: Refusal text supplied by the model
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
description: >-
Refusal content within a streamed response part.
OpenAIResponseInputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
@@ -4515,6 +4534,15 @@ components:
under one type because the Responses API gives them all the same "type" value,
and there is no way to tell them apart in certain scenarios.
OpenAIResponseOutputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
discriminator:
propertyName: type
mapping:
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
"OpenAIResponseOutputMessageContentOutputText":
type: object
properties:
text:
@@ -6359,25 +6387,6 @@ components:
title: OpenAIResponseContentPartReasoningText
description: >-
Reasoning text emitted as part of a streamed response.
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
description: >-
Content part type identifier, always "refusal"
refusal:
type: string
description: Refusal text supplied by the model
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
description: >-
Refusal content within a streamed response part.
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'

View file

@@ -7867,6 +7867,28 @@
}
}
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal",
"description": "Content part type identifier, always \"refusal\""
},
"refusal": {
"type": "string",
"description": "Refusal text supplied by the model"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal",
"description": "Refusal content within a streamed response part."
},
"OpenAIResponseInputMessageContent": { "OpenAIResponseInputMessageContent": {
"oneOf": [ "oneOf": [
{ {
@@ -8010,6 +8032,23 @@
"description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios."
},
"OpenAIResponseOutputMessageContent": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
},
{
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
}
},
"OpenAIResponseOutputMessageContentOutputText": {
"type": "object", "type": "object",
"properties": { "properties": {
"text": { "text": {
@@ -10395,28 +10434,6 @@
"title": "OpenAIResponseContentPartReasoningText",
"description": "Reasoning text emitted as part of a streamed response."
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal",
"description": "Content part type identifier, always \"refusal\""
},
"refusal": {
"type": "string",
"description": "Refusal text supplied by the model"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal",
"description": "Refusal content within a streamed response part."
},
"OpenAIResponseObjectStream": { "OpenAIResponseObjectStream": {
"oneOf": [ "oneOf": [
{ {

View file

@@ -5861,6 +5861,25 @@ components:
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
description: >-
Content part type identifier, always "refusal"
refusal:
type: string
description: Refusal text supplied by the model
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
description: >-
Refusal content within a streamed response part.
OpenAIResponseInputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
@@ -5960,6 +5979,15 @@ components:
under one type because the Responses API gives them all the same "type" value,
and there is no way to tell them apart in certain scenarios.
OpenAIResponseOutputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
discriminator:
propertyName: type
mapping:
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
"OpenAIResponseOutputMessageContentOutputText":
type: object
properties:
text:
@@ -7804,25 +7832,6 @@ components:
title: OpenAIResponseContentPartReasoningText
description: >-
Reasoning text emitted as part of a streamed response.
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
description: >-
Content part type identifier, always "refusal"
refusal:
type: string
description: Refusal text supplied by the model
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
description: >-
Refusal content within a streamed response part.
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'

View file

@@ -131,8 +131,19 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
OpenAIResponseOutputMessageContent = Annotated[
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@@ -878,18 +889,6 @@ class OpenAIResponseContentPartOutputText(BaseModel):
logprobs: list[dict[str, Any]] | None = None
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
@json_schema_type
class OpenAIResponseContentPartReasoningText(BaseModel):
"""Reasoning text emitted as part of a streamed response.

View file

@@ -53,6 +53,11 @@ from llama_stack.core.stack import (
cast_image_name_to_string,
replace_env_vars,
)
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.context import preserve_contexts_async_generator
@@ -244,12 +249,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user):
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
from llama_stack.core.testing_context import (
TEST_CONTEXT,
reset_test_context,
sync_test_context_from_provider_data,
)
test_context_token = sync_test_context_from_provider_data()
is_streaming = is_streaming_request(func.__name__, request, **kwargs)

View file

@@ -92,6 +92,7 @@ class MetaReferenceAgentsImpl(Agents):
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
conversations_api=self.conversations_api,
safety_api=self.safety_api,
)
async def create_agent(

View file

@@ -15,12 +15,15 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseContentPartRefusal,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
@@ -31,9 +34,11 @@ from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
Message,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
@@ -42,12 +47,15 @@ from llama_stack.providers.utils.responses.responses_store import (
_OpenAIResponseObjectWithInputAndMessages,
)
from ..safety import SafetyException
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_shield_ids,
run_multiple_shields,
)
logger = get_logger(name=__name__, category="openai_responses")
@@ -67,6 +75,7 @@ class OpenAIResponsesImpl:
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
conversations_api: Conversations,
safety_api: Safety,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
@@ -74,6 +83,7 @@
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.conversations_api = conversations_api
self.safety_api = safety_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@@ -225,9 +235,7 @@
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
shield_ids = extract_shield_ids(shields) if shields else []
if conversation is not None and previous_response_id is not None:
raise ValueError(
@@ -255,6 +263,7 @@
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
shield_ids=shield_ids,
)
if stream:
@@ -288,6 +297,42 @@
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _check_input_safety(
self, messages: list[Message], shield_ids: list[str]
) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
try:
await run_multiple_shields(self.safety_api, messages, shield_ids)
except SafetyException as e:
logger.info(f"Input shield violation: {e.violation.user_message}")
return OpenAIResponseContentPartRefusal(
refusal=e.violation.user_message or "Content blocked by safety shields"
)
async def _create_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create and yield refusal response events following the established streaming pattern."""
# Create initial response and yield created event
initial_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="in_progress",
output=[],
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Create completed refusal response using OpenAIResponseContentPartRefusal
refusal_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
@@ -301,6 +346,7 @@
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
shield_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
@@ -333,8 +379,11 @@
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
shield_ids=shield_ids,
)
# Output safety validation hook - delegated to streaming orchestrator for real-time validation
# Stream the response
final_response = None
failed_response = None

View file

@@ -13,10 +13,12 @@ from llama_stack.apis.agents.openai_responses import (
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
@@ -45,6 +47,7 @@ from llama_stack.apis.agents.openai_responses import (
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
CompletionMessage,
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
@@ -52,12 +55,18 @@ from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
StopReason,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from ..safety import SafetyException
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
from .utils import (
convert_chat_choice_to_response_message,
is_function_tool_call,
run_multiple_shields,
)
logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -93,6 +102,8 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
safety_api,
shield_ids: list[str] | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -101,6 +112,8 @@
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.safety_api = safety_api
self.shield_ids = shield_ids or []
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
@@ -110,6 +123,61 @@
self.citation_files: dict[str, str] = {}
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
async def _check_input_safety(self, messages: list[OpenAIMessageParam]) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
except SafetyException as e:
logger.info(f"Input shield violation: {e.violation.user_message}")
return OpenAIResponseContentPartRefusal(
refusal=e.violation.user_message or "Content blocked by safety shields"
)
async def _create_input_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create refusal response events for input safety violations."""
# Create the refusal content part explicitly with the correct structure
refusal_part = OpenAIResponseContentPartRefusal(refusal=refusal_content.refusal, type="refusal")
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_part], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
"""Check accumulated streaming text content against shields. Returns violation message if blocked."""
if not self.shield_ids or not accumulated_text:
return None
messages = [CompletionMessage(content=accumulated_text, stop_reason=StopReason.end_of_turn)]
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
except SafetyException as e:
logger.info(f"Output shield violation: {e.violation.user_message}")
return e.violation.user_message or "Generated content blocked by safety shields"
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
# Create a completed refusal response
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
@@ -154,6 +222,15 @@
sequence_number=self.sequence_number,
)
# Input safety validation - check messages before processing
if self.shield_ids:
input_refusal = await self._check_input_safety(self.ctx.messages)
if input_refusal:
# Return refusal response immediately
async for refusal_event in self._create_input_refusal_response_events(input_refusal):
yield refusal_event
return
async for stream_event in self._process_tools(output_messages):
yield stream_event
@@ -187,6 +264,10 @@
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
# If violation detected, skip the rest of processing since we already sent refusal
if self.violation_detected:
return
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data
@@ -475,6 +556,15 @@
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Safety check after processing all chunks
if chat_response_content:
accumulated_text = "".join(chat_response_content)
violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
if violation_message:
yield await self._create_refusal_response(violation_message)
self.violation_detected = True
return
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]

View file

@@ -7,6 +7,7 @@
import re
import uuid
from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
@@ -26,6 +27,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
)
from llama_stack.apis.inference import (
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
@@ -45,6 +47,7 @@ from llama_stack.apis.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import Safety
async def convert_chat_choice_to_response_message(
@@ -171,7 +174,7 @@ async def convert_response_input_to_chat_messages(
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
message_type = get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
@@ -240,7 +243,8 @@ async def convert_response_text_to_chat_response_format(
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str):
def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
"""Get the appropriate OpenAI message parameter type for a given role."""
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
@@ -307,3 +311,52 @@
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
"""Run multiple shields against messages and raise SafetyException for violations."""
if not shield_ids or not messages:
return
for shield_id in shield_ids:
response = await safety_api.run_shield(
shield_id=shield_id,
messages=messages,
params={},
)
if response.violation and response.violation.violation_level.name == "ERROR":
from ..safety import SafetyException
raise SafetyException(response.violation)
def extract_shield_ids(shields: list | None) -> list[str]:
"""Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects."""
if not shields:
return []
shield_ids = []
for shield in shields:
if isinstance(shield, str):
shield_ids.append(shield)
elif isinstance(shield, ResponseShieldSpec):
shield_ids.append(shield.type)
else:
raise ValueError(f"Unsupported shield type: {type(shield)}")
return shield_ids
def extract_text_content(content: str | list | None) -> str | None:
"""Extract text content from OpenAI message content (string or complex structure)."""
if isinstance(content, str):
return content
elif isinstance(content, list):
# Handle complex content - extract text parts only
text_parts = []
for part in content:
if hasattr(part, "text"):
text_parts.append(part.text)
elif hasattr(part, "type") and part.type == "refusal":
# Skip refusal parts - don't validate them again
continue
return " ".join(text_parts) if text_parts else None
return None

View file

@@ -247,12 +247,17 @@ class LlamaGuardShield:
self.safety_categories = safety_categories
def check_unsafe_response(self, response: str) -> str | None:
# Check for "unsafe\n<code>" format
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
# Check for direct category code format (e.g., "S1", "S2", etc.)
if re.match(r"^S\d+$", response):
return response
return None
def get_safety_categories(self) -> list[str]:

View file

@@ -25,13 +25,13 @@ classifiers = [
]
dependencies = [
"aiohttp",
"fastapi>=0.115.0,<1.0", # server
"fire", # for MCP in LLS client
"httpx",
"jinja2>=3.1.6",
"jsonschema",
"llama-stack-client>=0.2.23",
"openai>=1.107", # for expires_after support
"prompt-toolkit",
"python-dotenv",
"python-jose[cryptography]",
@@ -42,13 +42,13 @@ dependencies = [
"tiktoken",
"pillow",
"h11>=0.16.0",
"python-multipart>=0.0.20", # For fastapi Form
"uvicorn>=0.34.0", # server
"opentelemetry-sdk>=1.30.0", # server
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations
]
[project.optional-dependencies]
@@ -161,6 +161,7 @@ explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cpu" }]
torchvision = [{ index = "pytorch-cpu" }]
llama-stack-client = { path = "../llama-stack-client-python" }
[tool.ruff]
line-length = 120

View file

@@ -297,3 +297,130 @@ def test_function_call_output_response_with_none_arguments(openai_client, client
assert response.output[0].type == "function_call"
assert response.output[0].arguments == "{}"
_ = response.output[0].call_id
@pytest.mark.parametrize("stream", [True, False])
def test_input_safety_shields(compat_client, text_model_id, stream):
"""Test that input safety shields block inappropriate content."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
# Test with shield that should block inappropriate content
# Note: This test assumes shields are configured in the test environment
inappropriate_message = "How can I hurt someone?"
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": inappropriate_message,
}
],
stream=stream,
extra_body={"shields": ["llama-guard"]}, # Shields passed via extra_body
)
if stream:
# Input violations cause immediate refusal - only one completion chunk
final_response = None
for chunk in response:
if chunk.type == "response.completed":
final_response = chunk.response
break
assert final_response is not None
response_to_check = final_response
else:
response_to_check = response
# Check if the response contains a refusal (indicating shield blocked content)
assert response_to_check.output[0].type == "message"
message = response_to_check.output[0]
# Shield should ALWAYS block inappropriate input - expect refusal response
assert len(message.content) > 0, "Content should not be empty"
content_item = message.content[0]
assert content_item.type == "refusal", f"Expected refusal content, but got: {content_item.type}"
assert len(content_item.refusal.strip()) > 0, "Refusal message should not be empty"
@pytest.mark.parametrize("stream", [True, False])
def test_output_safety_shields(compat_client, text_model_id, stream):
"""Test that output safety shields can block generated content."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
# Use an ambiguous prompt that could potentially generate content requiring shield evaluation
# This prompt: "Describe self-defense techniques."
# Could result in:
# 1. Educational safety content → shield allows output → normal response
# 2. Violent techniques content → shield blocks output → refusal response
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": "Describe violent content.",
}
],
stream=stream,
extra_body={"shields": ["llama-guard"]}, # Output shield validation
)
if stream:
# For streaming, violations cause stream to end with refusal completion
last_chunk = None
for chunk in response:
last_chunk = chunk
assert last_chunk.type == "response.completed", f"Expected final chunk to be completion, got {last_chunk.type}"
response_to_check = last_chunk.response
else:
response_to_check = response
# Verify we get a proper response (this test mainly verifies the shield integration works)
assert response_to_check.output[0].type == "message"
message = response_to_check.output[0]
assert len(message.content) > 0, "Message should have content"
content_item = message.content[0]
assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"
def test_shields_with_tools(compat_client, text_model_id):
"""Test that shields work correctly when tools are present."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": "What's the weather like? Please help me in a safe and appropriate way.",
}
],
tools=[
{
"type": "function",
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city to get the weather for"},
},
},
}
],
extra_body={"shields": ["llama-guard"]},
stream=False,
)
# Verify response completes successfully with tools and shields
assert response.id is not None
assert len(response.output) > 0
# Response should be either a function call or a message
output_type = response.output[0].type
assert output_type in ["function_call", "message"]

View file

@@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseContentPartRefusal,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@@ -38,7 +39,9 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
@@ -90,6 +93,12 @@ def mock_conversations_api():
return mock_api
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
return safety_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api,
@@ -98,6 +107,7 @@
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
mock_safety_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@@ -106,6 +116,7 @@
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
)
@@ -1082,3 +1093,52 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
model=model,
text=OpenAIResponseText(format={"type": "invalid"}),
)
async def test_check_input_safety_no_violation(openai_responses_impl):
"""Test input shield validation with no violations."""
messages = [UserMessage(content="Hello world")]
shield_ids = ["llama-guard"]
# Mock successful shield validation (no violation)
mock_response = AsyncMock()
mock_response.violation = None
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert result is None
openai_responses_impl.safety_api.run_shield.assert_called_once_with(
shield_id="llama-guard", messages=messages, params={}
)
async def test_check_input_safety_with_violation(openai_responses_impl):
"""Test input shield validation with safety violation."""
messages = [UserMessage(content="Harmful content")]
shield_ids = ["llama-guard"]
# Mock shield violation
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
)
mock_response = AsyncMock()
mock_response.violation = violation
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert isinstance(result, OpenAIResponseContentPartRefusal)
assert result.refusal == "Content violates safety guidelines"
assert result.type == "refusal"
async def test_check_input_safety_empty_inputs(openai_responses_impl):
"""Test input shield validation with empty inputs."""
# Test empty shield_ids
result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
assert result is None
# Test empty messages
result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
assert result is None

View file

@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
extract_shield_ids,
extract_text_content,
)
@pytest.fixture
def mock_apis():
"""Create mock APIs for testing."""
return {
"inference_api": AsyncMock(),
"tool_groups_api": AsyncMock(),
"tool_runtime_api": AsyncMock(),
"responses_store": AsyncMock(),
"vector_io_api": AsyncMock(),
"conversations_api": AsyncMock(),
"safety_api": AsyncMock(),
}
@pytest.fixture
def responses_impl(mock_apis):
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
return OpenAIResponsesImpl(**mock_apis)
# ============================================================================
# Shield ID Extraction Tests
# ============================================================================
def test_extract_shield_ids_from_strings(responses_impl):
"""Test extraction from simple string shield IDs."""
shields = ["llama-guard", "content-filter", "nsfw-detector"]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_from_objects(responses_impl):
"""Test extraction from ResponseShieldSpec objects."""
shields = [
ResponseShieldSpec(type="llama-guard"),
ResponseShieldSpec(type="content-filter"),
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter"]
def test_extract_shield_ids_mixed_formats(responses_impl):
"""Test extraction from mixed string and object formats."""
shields = [
"llama-guard",
ResponseShieldSpec(type="content-filter"),
"nsfw-detector",
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_none_input(responses_impl):
"""Test extraction with None input."""
result = extract_shield_ids(None)
assert result == []
def test_extract_shield_ids_empty_list(responses_impl):
"""Test extraction with empty list."""
result = extract_shield_ids([])
assert result == []
def test_extract_shield_ids_unknown_format(responses_impl):
"""Test extraction with unknown shield format raises ValueError."""
# Create an object that's neither string nor ResponseShieldSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec
shields = ["valid-shield", unknown_object, "another-shield"]
with pytest.raises(ValueError, match="Unsupported shield type"):
extract_shield_ids(shields)
# ============================================================================
# Text Content Extraction Tests
# ============================================================================
def test_extract_text_content_string(responses_impl):
"""Test extraction from simple string content."""
content = "Hello world"
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_text(responses_impl):
"""Test extraction from list content with text parts."""
content = [
MagicMock(text="Hello "),
MagicMock(text="world"),
]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_refusal(responses_impl):
"""Test extraction skips refusal parts."""
# Create text parts
text_part1 = MagicMock()
text_part1.text = "Hello"
text_part2 = MagicMock()
text_part2.text = "world"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
del refusal_part.text # Remove text attribute
content = [text_part1, refusal_part, text_part2]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_empty_list(responses_impl):
"""Test extraction from empty list returns None."""
content = []
result = extract_text_content(content)
assert result is None
def test_extract_text_content_no_text_parts(responses_impl):
"""Test extraction with no text parts returns None."""
# Create image part (no text attribute)
image_part = MagicMock()
image_part.type = "image"
image_part.image_url = "http://example.com"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
# Explicitly remove text attributes to simulate non-text parts
if hasattr(image_part, "text"):
delattr(image_part, "text")
if hasattr(refusal_part, "text"):
delattr(refusal_part, "text")
content = [image_part, refusal_part]
result = extract_text_content(content)
assert result is None
def test_extract_text_content_none_input(responses_impl):
"""Test extraction with None input returns None."""
result = extract_text_content(None)
assert result is None

uv.lock generated
View file

@@ -1897,8 +1897,8 @@ requires-dist = [
{ name = "httpx" },
{ name = "jinja2", specifier = ">=3.1.6" },
{ name = "jsonschema" },
{ name = "llama-stack-client", specifier = ">=0.2.23" }, { name = "llama-stack-client", directory = "../llama-stack-client-python" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" }, { name = "llama-stack-client", marker = "extra == 'ui'", directory = "../llama-stack-client-python" },
{ name = "openai", specifier = ">=1.107" }, { name = "openai", specifier = ">=1.107" },
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
@@ -2004,8 +2004,8 @@ unit = [
[[package]]
name = "llama-stack-client"
version = "0.2.23"
version = "0.3.0a3"
source = { registry = "https://pypi.org/simple" }
source = { directory = "../llama-stack-client-python" }
dependencies = [
{ name = "anyio" },
{ name = "click" },
@@ -2023,10 +2023,47 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9f/8f/306d5fcf2f97b3a6251219b03c194836a2ff4e0fcc8146c9970e50a72cd3/llama_stack_client-0.2.23.tar.gz", hash = "sha256:68f34e8ac8eea6a73ed9d4977d849992b2d8bd835804d770a11843431cd5bf74", size = 322288, upload-time = "2025-09-26T21:11:08.342Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/75/3eb58e092a681804013dbec7b7f549d18f55acf6fd6e6b27de7e249766d8/llama_stack_client-0.2.23-py3-none-any.whl", hash = "sha256:eee42c74eee8f218f9455e5a06d5d4be43f8a8c82a7937ef51ce367f916df847", size = 379809, upload-time = "2025-09-26T21:11:06.856Z" },
[package.metadata]
requires-dist = [
{ name = "aiohttp", marker = "extra == 'aiohttp'" },
{ name = "anyio", specifier = ">=3.5.0,<5" },
{ name = "click" },
{ name = "distro", specifier = ">=1.7.0,<2" },
{ name = "fire" },
{ name = "httpx", specifier = ">=0.23.0,<1" },
{ name = "httpx-aiohttp", marker = "extra == 'aiohttp'", specifier = ">=0.1.8" },
{ name = "pandas" },
{ name = "prompt-toolkit" },
{ name = "pyaml" },
{ name = "pydantic", specifier = ">=1.9.0,<3" },
{ name = "requests" },
{ name = "rich" },
{ name = "sniffio" },
{ name = "termcolor" },
{ name = "tqdm" },
{ name = "typing-extensions", specifier = ">=4.7,<5" },
]
provides-extras = ["aiohttp"]
[package.metadata.requires-dev]
dev = [
{ name = "black" },
{ name = "dirty-equals", specifier = ">=0.6.0" },
{ name = "importlib-metadata", specifier = ">=6.7.0" },
{ name = "mypy" },
{ name = "pre-commit" },
{ name = "pyright", specifier = "==1.1.399" },
{ name = "pytest", specifier = ">=7.1.1" },
{ name = "pytest-asyncio" },
{ name = "pytest-xdist", specifier = ">=3.6.1" },
{ name = "respx" },
{ name = "rich", specifier = ">=13.7.1" },
{ name = "ruff" },
{ name = "time-machine" },
]
pydantic-v1 = [{ name = "pydantic", specifier = ">=1.9.0,<2" }]
pydantic-v2 = [{ name = "pydantic", specifier = ">=2,<3" }]
[[package]]
name = "locust"