mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
feat: Add responses and safety impl with extra body
This commit is contained in:
parent
6954fe2274
commit
9152efa1a9
18 changed files with 833 additions and 164 deletions
61
docs/static/deprecated-llama-stack-spec.html
vendored
61
docs/static/deprecated-llama-stack-spec.html
vendored
|
|
@ -8821,6 +8821,28 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseContentPartRefusal": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "refusal",
|
||||
"default": "refusal",
|
||||
"description": "Content part type identifier, always \"refusal\""
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string",
|
||||
"description": "Refusal text supplied by the model"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"refusal"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartRefusal",
|
||||
"description": "Refusal content within a streamed response part."
|
||||
},
|
||||
"OpenAIResponseError": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
@ -9395,6 +9417,23 @@
|
|||
}
|
||||
},
|
||||
"OpenAIResponseOutputMessageContent": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
|
||||
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseOutputMessageContentOutputText": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
|
|
@ -10291,28 +10330,6 @@
|
|||
"title": "OpenAIResponseContentPartReasoningText",
|
||||
"description": "Reasoning text emitted as part of a streamed response."
|
||||
},
|
||||
"OpenAIResponseContentPartRefusal": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "refusal",
|
||||
"default": "refusal",
|
||||
"description": "Content part type identifier, always \"refusal\""
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string",
|
||||
"description": "Refusal text supplied by the model"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"refusal"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartRefusal",
|
||||
"description": "Refusal content within a streamed response part."
|
||||
},
|
||||
"OpenAIResponseObjectStream": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
|
|||
47
docs/static/deprecated-llama-stack-spec.yaml
vendored
47
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -6551,6 +6551,25 @@ components:
|
|||
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
|
||||
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
|
||||
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
|
||||
OpenAIResponseContentPartRefusal:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: refusal
|
||||
default: refusal
|
||||
description: >-
|
||||
Content part type identifier, always "refusal"
|
||||
refusal:
|
||||
type: string
|
||||
description: Refusal text supplied by the model
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- refusal
|
||||
title: OpenAIResponseContentPartRefusal
|
||||
description: >-
|
||||
Refusal content within a streamed response part.
|
||||
OpenAIResponseError:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -6972,6 +6991,15 @@ components:
|
|||
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
|
||||
mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
|
||||
OpenAIResponseOutputMessageContent:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
|
||||
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
|
||||
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
"OpenAIResponseOutputMessageContentOutputText":
|
||||
type: object
|
||||
properties:
|
||||
text:
|
||||
|
|
@ -7663,25 +7691,6 @@ components:
|
|||
title: OpenAIResponseContentPartReasoningText
|
||||
description: >-
|
||||
Reasoning text emitted as part of a streamed response.
|
||||
OpenAIResponseContentPartRefusal:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: refusal
|
||||
default: refusal
|
||||
description: >-
|
||||
Content part type identifier, always "refusal"
|
||||
refusal:
|
||||
type: string
|
||||
description: Refusal text supplied by the model
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- refusal
|
||||
title: OpenAIResponseContentPartRefusal
|
||||
description: >-
|
||||
Refusal content within a streamed response part.
|
||||
OpenAIResponseObjectStream:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||
|
|
|
|||
61
docs/static/llama-stack-spec.html
vendored
61
docs/static/llama-stack-spec.html
vendored
|
|
@ -5858,6 +5858,28 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseContentPartRefusal": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "refusal",
|
||||
"default": "refusal",
|
||||
"description": "Content part type identifier, always \"refusal\""
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string",
|
||||
"description": "Refusal text supplied by the model"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"refusal"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartRefusal",
|
||||
"description": "Refusal content within a streamed response part."
|
||||
},
|
||||
"OpenAIResponseInputMessageContent": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
@ -6001,6 +6023,23 @@
|
|||
"description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios."
|
||||
},
|
||||
"OpenAIResponseOutputMessageContent": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
|
||||
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseOutputMessageContentOutputText": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
|
|
@ -8386,28 +8425,6 @@
|
|||
"title": "OpenAIResponseContentPartReasoningText",
|
||||
"description": "Reasoning text emitted as part of a streamed response."
|
||||
},
|
||||
"OpenAIResponseContentPartRefusal": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "refusal",
|
||||
"default": "refusal",
|
||||
"description": "Content part type identifier, always \"refusal\""
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string",
|
||||
"description": "Refusal text supplied by the model"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"refusal"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartRefusal",
|
||||
"description": "Refusal content within a streamed response part."
|
||||
},
|
||||
"OpenAIResponseObjectStream": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
|
|||
47
docs/static/llama-stack-spec.yaml
vendored
47
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -4416,6 +4416,25 @@ components:
|
|||
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
|
||||
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
|
||||
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
|
||||
OpenAIResponseContentPartRefusal:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: refusal
|
||||
default: refusal
|
||||
description: >-
|
||||
Content part type identifier, always "refusal"
|
||||
refusal:
|
||||
type: string
|
||||
description: Refusal text supplied by the model
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- refusal
|
||||
title: OpenAIResponseContentPartRefusal
|
||||
description: >-
|
||||
Refusal content within a streamed response part.
|
||||
OpenAIResponseInputMessageContent:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
|
||||
|
|
@ -4515,6 +4534,15 @@ components:
|
|||
under one type because the Responses API gives them all the same "type" value,
|
||||
and there is no way to tell them apart in certain scenarios.
|
||||
OpenAIResponseOutputMessageContent:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
|
||||
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
|
||||
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
"OpenAIResponseOutputMessageContentOutputText":
|
||||
type: object
|
||||
properties:
|
||||
text:
|
||||
|
|
@ -6359,25 +6387,6 @@ components:
|
|||
title: OpenAIResponseContentPartReasoningText
|
||||
description: >-
|
||||
Reasoning text emitted as part of a streamed response.
|
||||
OpenAIResponseContentPartRefusal:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: refusal
|
||||
default: refusal
|
||||
description: >-
|
||||
Content part type identifier, always "refusal"
|
||||
refusal:
|
||||
type: string
|
||||
description: Refusal text supplied by the model
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- refusal
|
||||
title: OpenAIResponseContentPartRefusal
|
||||
description: >-
|
||||
Refusal content within a streamed response part.
|
||||
OpenAIResponseObjectStream:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||
|
|
|
|||
61
docs/static/stainless-llama-stack-spec.html
vendored
61
docs/static/stainless-llama-stack-spec.html
vendored
|
|
@ -7867,6 +7867,28 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseContentPartRefusal": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "refusal",
|
||||
"default": "refusal",
|
||||
"description": "Content part type identifier, always \"refusal\""
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string",
|
||||
"description": "Refusal text supplied by the model"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"refusal"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartRefusal",
|
||||
"description": "Refusal content within a streamed response part."
|
||||
},
|
||||
"OpenAIResponseInputMessageContent": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
@ -8010,6 +8032,23 @@
|
|||
"description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios."
|
||||
},
|
||||
"OpenAIResponseOutputMessageContent": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
|
||||
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
|
||||
}
|
||||
}
|
||||
},
|
||||
"OpenAIResponseOutputMessageContentOutputText": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
|
|
@ -10395,28 +10434,6 @@
|
|||
"title": "OpenAIResponseContentPartReasoningText",
|
||||
"description": "Reasoning text emitted as part of a streamed response."
|
||||
},
|
||||
"OpenAIResponseContentPartRefusal": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"const": "refusal",
|
||||
"default": "refusal",
|
||||
"description": "Content part type identifier, always \"refusal\""
|
||||
},
|
||||
"refusal": {
|
||||
"type": "string",
|
||||
"description": "Refusal text supplied by the model"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"type",
|
||||
"refusal"
|
||||
],
|
||||
"title": "OpenAIResponseContentPartRefusal",
|
||||
"description": "Refusal content within a streamed response part."
|
||||
},
|
||||
"OpenAIResponseObjectStream": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
|
|
|||
47
docs/static/stainless-llama-stack-spec.yaml
vendored
47
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -5861,6 +5861,25 @@ components:
|
|||
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
|
||||
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
|
||||
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
|
||||
OpenAIResponseContentPartRefusal:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: refusal
|
||||
default: refusal
|
||||
description: >-
|
||||
Content part type identifier, always "refusal"
|
||||
refusal:
|
||||
type: string
|
||||
description: Refusal text supplied by the model
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- refusal
|
||||
title: OpenAIResponseContentPartRefusal
|
||||
description: >-
|
||||
Refusal content within a streamed response part.
|
||||
OpenAIResponseInputMessageContent:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
|
||||
|
|
@ -5960,6 +5979,15 @@ components:
|
|||
under one type because the Responses API gives them all the same "type" value,
|
||||
and there is no way to tell them apart in certain scenarios.
|
||||
OpenAIResponseOutputMessageContent:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
|
||||
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
discriminator:
|
||||
propertyName: type
|
||||
mapping:
|
||||
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
|
||||
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
|
||||
"OpenAIResponseOutputMessageContentOutputText":
|
||||
type: object
|
||||
properties:
|
||||
text:
|
||||
|
|
@ -7804,25 +7832,6 @@ components:
|
|||
title: OpenAIResponseContentPartReasoningText
|
||||
description: >-
|
||||
Reasoning text emitted as part of a streamed response.
|
||||
OpenAIResponseContentPartRefusal:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: refusal
|
||||
default: refusal
|
||||
description: >-
|
||||
Content part type identifier, always "refusal"
|
||||
refusal:
|
||||
type: string
|
||||
description: Refusal text supplied by the model
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
- refusal
|
||||
title: OpenAIResponseContentPartRefusal
|
||||
description: >-
|
||||
Refusal content within a streamed response part.
|
||||
OpenAIResponseObjectStream:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||
|
|
|
|||
|
|
@ -131,8 +131,19 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
|
|||
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartRefusal(BaseModel):
|
||||
"""Refusal content within a streamed response part.
|
||||
:param type: Content part type identifier, always "refusal"
|
||||
:param refusal: Refusal text supplied by the model
|
||||
"""
|
||||
|
||||
type: Literal["refusal"] = "refusal"
|
||||
refusal: str
|
||||
|
||||
|
||||
OpenAIResponseOutputMessageContent = Annotated[
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
|
||||
|
|
@ -878,18 +889,6 @@ class OpenAIResponseContentPartOutputText(BaseModel):
|
|||
logprobs: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartRefusal(BaseModel):
|
||||
"""Refusal content within a streamed response part.
|
||||
|
||||
:param type: Content part type identifier, always "refusal"
|
||||
:param refusal: Refusal text supplied by the model
|
||||
"""
|
||||
|
||||
type: Literal["refusal"] = "refusal"
|
||||
refusal: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseContentPartReasoningText(BaseModel):
|
||||
"""Reasoning text emitted as part of a streamed response.
|
||||
|
|
|
|||
|
|
@ -53,6 +53,11 @@ from llama_stack.core.stack import (
|
|||
cast_image_name_to_string,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.core.testing_context import (
|
||||
TEST_CONTEXT,
|
||||
reset_test_context,
|
||||
sync_test_context_from_provider_data,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
|
|
@ -244,12 +249,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user):
|
||||
if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE"):
|
||||
from llama_stack.core.testing_context import (
|
||||
TEST_CONTEXT,
|
||||
reset_test_context,
|
||||
sync_test_context_from_provider_data,
|
||||
)
|
||||
|
||||
test_context_token = sync_test_context_from_provider_data()
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
responses_store=self.responses_store,
|
||||
vector_io_api=self.vector_io_api,
|
||||
conversations_api=self.conversations_api,
|
||||
safety_api=self.safety_api,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
|
|
|
|||
|
|
@ -15,12 +15,15 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
|
|
@ -31,9 +34,11 @@ from llama_stack.apis.conversations import Conversations
|
|||
from llama_stack.apis.conversations.conversations import ConversationItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
Message,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -42,12 +47,15 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
|
||||
from ..safety import SafetyException
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext, ToolContext
|
||||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
extract_shield_ids,
|
||||
run_multiple_shields,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
|
@ -67,6 +75,7 @@ class OpenAIResponsesImpl:
|
|||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
conversations_api: Conversations,
|
||||
safety_api: Safety,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
|
|
@ -74,6 +83,7 @@ class OpenAIResponsesImpl:
|
|||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.conversations_api = conversations_api
|
||||
self.safety_api = safety_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
tool_runtime_api=tool_runtime_api,
|
||||
|
|
@ -225,9 +235,7 @@ class OpenAIResponsesImpl:
|
|||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
# Shields parameter received via extra_body - not yet implemented
|
||||
if shields is not None:
|
||||
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
|
||||
shield_ids = extract_shield_ids(shields) if shields else []
|
||||
|
||||
if conversation is not None and previous_response_id is not None:
|
||||
raise ValueError(
|
||||
|
|
@ -255,6 +263,7 @@ class OpenAIResponsesImpl:
|
|||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
shield_ids=shield_ids,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
@ -288,6 +297,42 @@ class OpenAIResponsesImpl:
|
|||
raise ValueError("The response stream never reached a terminal state")
|
||||
return final_response
|
||||
|
||||
async def _check_input_safety(
|
||||
self, messages: list[Message], shield_ids: list[str]
|
||||
) -> OpenAIResponseContentPartRefusal | None:
|
||||
"""Validate input messages against shields. Returns refusal content if violation found."""
|
||||
try:
|
||||
await run_multiple_shields(self.safety_api, messages, shield_ids)
|
||||
except SafetyException as e:
|
||||
logger.info(f"Input shield violation: {e.violation.user_message}")
|
||||
return OpenAIResponseContentPartRefusal(
|
||||
refusal=e.violation.user_message or "Content blocked by safety shields"
|
||||
)
|
||||
|
||||
async def _create_refusal_response_events(
|
||||
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create and yield refusal response events following the established streaming pattern."""
|
||||
# Create initial response and yield created event
|
||||
initial_response = OpenAIResponseObject(
|
||||
id=response_id,
|
||||
created_at=created_at,
|
||||
model=model,
|
||||
status="in_progress",
|
||||
output=[],
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
# Create completed refusal response using OpenAIResponseContentPartRefusal
|
||||
refusal_response = OpenAIResponseObject(
|
||||
id=response_id,
|
||||
created_at=created_at,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
|
|
@ -301,6 +346,7 @@ class OpenAIResponsesImpl:
|
|||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
shield_ids: list[str] | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||
|
|
@ -333,8 +379,11 @@ class OpenAIResponsesImpl:
|
|||
text=text,
|
||||
max_infer_iters=max_infer_iters,
|
||||
tool_executor=self.tool_executor,
|
||||
safety_api=self.safety_api,
|
||||
shield_ids=shield_ids,
|
||||
)
|
||||
|
||||
# Output safety validation hook - delegated to streaming orchestrator for real-time validation
|
||||
# Stream the response
|
||||
final_response = None
|
||||
failed_response = None
|
||||
|
|
|
|||
|
|
@ -13,10 +13,12 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
ApprovalFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseError,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
|
|
@ -45,6 +47,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -52,12 +55,18 @@ from llama_stack.apis.inference import (
|
|||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
StopReason,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from ..safety import SafetyException
|
||||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
from .utils import (
|
||||
convert_chat_choice_to_response_message,
|
||||
is_function_tool_call,
|
||||
run_multiple_shields,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
|
@ -93,6 +102,8 @@ class StreamingResponseOrchestrator:
|
|||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
safety_api,
|
||||
shield_ids: list[str] | None = None,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
|
|
@ -101,6 +112,8 @@ class StreamingResponseOrchestrator:
|
|||
self.text = text
|
||||
self.max_infer_iters = max_infer_iters
|
||||
self.tool_executor = tool_executor
|
||||
self.safety_api = safety_api
|
||||
self.shield_ids = shield_ids or []
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
|
||||
|
|
@ -110,6 +123,61 @@ class StreamingResponseOrchestrator:
|
|||
self.citation_files: dict[str, str] = {}
|
||||
# Track accumulated usage across all inference calls
|
||||
self.accumulated_usage: OpenAIResponseUsage | None = None
|
||||
# Track if we've sent a refusal response
|
||||
self.violation_detected = False
|
||||
|
||||
async def _check_input_safety(self, messages: list[OpenAIMessageParam]) -> OpenAIResponseContentPartRefusal | None:
|
||||
"""Validate input messages against shields. Returns refusal content if violation found."""
|
||||
try:
|
||||
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
|
||||
except SafetyException as e:
|
||||
logger.info(f"Input shield violation: {e.violation.user_message}")
|
||||
return OpenAIResponseContentPartRefusal(
|
||||
refusal=e.violation.user_message or "Content blocked by safety shields"
|
||||
)
|
||||
|
||||
async def _create_input_refusal_response_events(
|
||||
self, refusal_content: OpenAIResponseContentPartRefusal
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create refusal response events for input safety violations."""
|
||||
# Create the refusal content part explicitly with the correct structure
|
||||
refusal_part = OpenAIResponseContentPartRefusal(refusal=refusal_content.refusal, type="refusal")
|
||||
refusal_response = OpenAIResponseObject(
|
||||
id=self.response_id,
|
||||
created_at=self.created_at,
|
||||
model=self.ctx.model,
|
||||
status="completed",
|
||||
output=[OpenAIResponseMessage(role="assistant", content=[refusal_part], type="message")],
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
|
||||
|
||||
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
|
||||
"""Check accumulated streaming text content against shields. Returns violation message if blocked."""
|
||||
if not self.shield_ids or not accumulated_text:
|
||||
return None
|
||||
|
||||
messages = [CompletionMessage(content=accumulated_text, stop_reason=StopReason.end_of_turn)]
|
||||
|
||||
try:
|
||||
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
|
||||
except SafetyException as e:
|
||||
logger.info(f"Output shield violation: {e.violation.user_message}")
|
||||
return e.violation.user_message or "Generated content blocked by safety shields"
|
||||
|
||||
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
|
||||
"""Create a refusal response to replace streaming content."""
|
||||
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
|
||||
|
||||
# Create a completed refusal response
|
||||
refusal_response = OpenAIResponseObject(
|
||||
id=self.response_id,
|
||||
created_at=self.created_at,
|
||||
model=self.ctx.model,
|
||||
status="completed",
|
||||
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
|
||||
)
|
||||
|
||||
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
|
||||
|
||||
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
|
||||
cloned: list[OpenAIResponseOutput] = []
|
||||
|
|
@ -154,6 +222,15 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Input safety validation - check messages before processing
|
||||
if self.shield_ids:
|
||||
input_refusal = await self._check_input_safety(self.ctx.messages)
|
||||
if input_refusal:
|
||||
# Return refusal response immediately
|
||||
async for refusal_event in self._create_input_refusal_response_events(input_refusal):
|
||||
yield refusal_event
|
||||
return
|
||||
|
||||
async for stream_event in self._process_tools(output_messages):
|
||||
yield stream_event
|
||||
|
||||
|
|
@ -187,6 +264,10 @@ class StreamingResponseOrchestrator:
|
|||
completion_result_data = stream_event_or_result
|
||||
else:
|
||||
yield stream_event_or_result
|
||||
# If violation detected, skip the rest of processing since we already sent refusal
|
||||
if self.violation_detected:
|
||||
return
|
||||
|
||||
if not completion_result_data:
|
||||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
last_completion_result = completion_result_data
|
||||
|
|
@ -475,6 +556,15 @@ class StreamingResponseOrchestrator:
|
|||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
||||
# Safety check after processing all chunks
|
||||
if chat_response_content:
|
||||
accumulated_text = "".join(chat_response_content)
|
||||
violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
|
||||
if violation_message:
|
||||
yield await self._create_refusal_response(violation_message)
|
||||
self.violation_detected = True
|
||||
return
|
||||
|
||||
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call = chat_response_tool_calls[tool_call_index]
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import re
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseShieldSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseInput,
|
||||
|
|
@ -26,6 +27,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
|
|
@ -45,6 +47,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(
|
||||
|
|
@ -171,7 +174,7 @@ async def convert_response_input_to_chat_messages(
|
|||
pass
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
message_type = get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
|
|
@ -240,7 +243,8 @@ async def convert_response_text_to_chat_response_format(
|
|||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def get_message_type_by_role(role: str):
|
||||
def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
|
||||
"""Get the appropriate OpenAI message parameter type for a given role."""
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
|
|
@ -307,3 +311,52 @@ def is_function_tool_call(
|
|||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
|
||||
"""Run multiple shields against messages and raise SafetyException for violations."""
|
||||
if not shield_ids or not messages:
|
||||
return
|
||||
for shield_id in shield_ids:
|
||||
response = await safety_api.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params={},
|
||||
)
|
||||
if response.violation and response.violation.violation_level.name == "ERROR":
|
||||
from ..safety import SafetyException
|
||||
|
||||
raise SafetyException(response.violation)
|
||||
|
||||
|
||||
def extract_shield_ids(shields: list | None) -> list[str]:
|
||||
"""Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects."""
|
||||
if not shields:
|
||||
return []
|
||||
|
||||
shield_ids = []
|
||||
for shield in shields:
|
||||
if isinstance(shield, str):
|
||||
shield_ids.append(shield)
|
||||
elif isinstance(shield, ResponseShieldSpec):
|
||||
shield_ids.append(shield.type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported shield type: {type(shield)}")
|
||||
|
||||
return shield_ids
|
||||
|
||||
|
||||
def extract_text_content(content: str | list | None) -> str | None:
|
||||
"""Extract text content from OpenAI message content (string or complex structure)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Handle complex content - extract text parts only
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if hasattr(part, "text"):
|
||||
text_parts.append(part.text)
|
||||
elif hasattr(part, "type") and part.type == "refusal":
|
||||
# Skip refusal parts - don't validate them again
|
||||
continue
|
||||
return " ".join(text_parts) if text_parts else None
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -247,12 +247,17 @@ class LlamaGuardShield:
|
|||
self.safety_categories = safety_categories
|
||||
|
||||
def check_unsafe_response(self, response: str) -> str | None:
|
||||
# Check for "unsafe\n<code>" format
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
if match:
|
||||
# extracts the unsafe code
|
||||
extracted = match.group(1)
|
||||
return extracted
|
||||
|
||||
# Check for direct category code format (e.g., "S1", "S2", etc.)
|
||||
if re.match(r"^S\d+$", response):
|
||||
return response
|
||||
|
||||
return None
|
||||
|
||||
def get_safety_categories(self) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ explicit = true
|
|||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cpu" }]
|
||||
torchvision = [{ index = "pytorch-cpu" }]
|
||||
llama-stack-client = { path = "../llama-stack-client-python" }
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
|
|
|||
|
|
@ -297,3 +297,130 @@ def test_function_call_output_response_with_none_arguments(openai_client, client
|
|||
assert response.output[0].type == "function_call"
|
||||
assert response.output[0].arguments == "{}"
|
||||
_ = response.output[0].call_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_input_safety_shields(compat_client, text_model_id, stream):
|
||||
"""Test that input safety shields block inappropriate content."""
|
||||
if not isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
|
||||
|
||||
# Test with shield that should block inappropriate content
|
||||
# Note: This test assumes shields are configured in the test environment
|
||||
inappropriate_message = "How can I hurt someone?"
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": inappropriate_message,
|
||||
}
|
||||
],
|
||||
stream=stream,
|
||||
extra_body={"shields": ["llama-guard"]}, # Shields passed via extra_body
|
||||
)
|
||||
|
||||
if stream:
|
||||
# Input violations cause immediate refusal - only one completion chunk
|
||||
final_response = None
|
||||
for chunk in response:
|
||||
if chunk.type == "response.completed":
|
||||
final_response = chunk.response
|
||||
break
|
||||
|
||||
assert final_response is not None
|
||||
response_to_check = final_response
|
||||
else:
|
||||
response_to_check = response
|
||||
|
||||
# Check if the response contains a refusal (indicating shield blocked content)
|
||||
assert response_to_check.output[0].type == "message"
|
||||
message = response_to_check.output[0]
|
||||
|
||||
# Shield should ALWAYS block inappropriate input - expect refusal response
|
||||
assert len(message.content) > 0, "Content should not be empty"
|
||||
content_item = message.content[0]
|
||||
assert content_item.type == "refusal", f"Expected refusal content, but got: {content_item.type}"
|
||||
assert len(content_item.refusal.strip()) > 0, "Refusal message should not be empty"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_output_safety_shields(compat_client, text_model_id, stream):
|
||||
"""Test that output safety shields can block generated content."""
|
||||
if not isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
|
||||
|
||||
# Use an ambiguous prompt that could potentially generate content requiring shield evaluation
|
||||
# This prompt: "Describe self-defense techniques."
|
||||
# Could result in:
|
||||
# 1. Educational safety content → shield allows output → normal response
|
||||
# 2. Violent techniques content → shield blocks output → refusal response
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Describe violent content.",
|
||||
}
|
||||
],
|
||||
stream=stream,
|
||||
extra_body={"shields": ["llama-guard"]}, # Output shield validation
|
||||
)
|
||||
|
||||
if stream:
|
||||
# For streaming, violations cause stream to end with refusal completion
|
||||
last_chunk = None
|
||||
for chunk in response:
|
||||
last_chunk = chunk
|
||||
|
||||
assert last_chunk.type == "response.completed", f"Expected final chunk to be completion, got {last_chunk.type}"
|
||||
response_to_check = last_chunk.response
|
||||
else:
|
||||
response_to_check = response
|
||||
# Verify we get a proper response (this test mainly verifies the shield integration works)
|
||||
assert response_to_check.output[0].type == "message"
|
||||
message = response_to_check.output[0]
|
||||
|
||||
assert len(message.content) > 0, "Message should have content"
|
||||
content_item = message.content[0]
|
||||
assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"
|
||||
|
||||
|
||||
def test_shields_with_tools(compat_client, text_model_id):
|
||||
"""Test that shields work correctly when tools are present."""
|
||||
if not isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
|
||||
|
||||
response = compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like? Please help me in a safe and appropriate way.",
|
||||
}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city to get the weather for"},
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
extra_body={"shields": ["llama-guard"]},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Verify response completes successfully with tools and shields
|
||||
assert response.id is not None
|
||||
assert len(response.output) > 0
|
||||
|
||||
# Response should be either a function call or a message
|
||||
output_type = response.output[0].type
|
||||
assert output_type in ["function_call", "message"]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
|
|
@ -38,7 +39,9 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
|
|
@ -90,6 +93,12 @@ def mock_conversations_api():
|
|||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
safety_api = AsyncMock()
|
||||
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api,
|
||||
|
|
@ -98,6 +107,7 @@ def openai_responses_impl(
|
|||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
mock_safety_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
|
|
@ -106,6 +116,7 @@ def openai_responses_impl(
|
|||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
safety_api=mock_safety_api,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1082,3 +1093,52 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
|
|||
model=model,
|
||||
text=OpenAIResponseText(format={"type": "invalid"}),
|
||||
)
|
||||
|
||||
|
||||
async def test_check_input_safety_no_violation(openai_responses_impl):
|
||||
"""Test input shield validation with no violations."""
|
||||
messages = [UserMessage(content="Hello world")]
|
||||
shield_ids = ["llama-guard"]
|
||||
|
||||
# Mock successful shield validation (no violation)
|
||||
mock_response = AsyncMock()
|
||||
mock_response.violation = None
|
||||
openai_responses_impl.safety_api.run_shield.return_value = mock_response
|
||||
|
||||
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
|
||||
|
||||
assert result is None
|
||||
openai_responses_impl.safety_api.run_shield.assert_called_once_with(
|
||||
shield_id="llama-guard", messages=messages, params={}
|
||||
)
|
||||
|
||||
|
||||
async def test_check_input_safety_with_violation(openai_responses_impl):
|
||||
"""Test input shield validation with safety violation."""
|
||||
messages = [UserMessage(content="Harmful content")]
|
||||
shield_ids = ["llama-guard"]
|
||||
|
||||
# Mock shield violation
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
|
||||
)
|
||||
mock_response = AsyncMock()
|
||||
mock_response.violation = violation
|
||||
openai_responses_impl.safety_api.run_shield.return_value = mock_response
|
||||
|
||||
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
|
||||
|
||||
assert isinstance(result, OpenAIResponseContentPartRefusal)
|
||||
assert result.refusal == "Content violates safety guidelines"
|
||||
assert result.type == "refusal"
|
||||
|
||||
|
||||
async def test_check_input_safety_empty_inputs(openai_responses_impl):
|
||||
"""Test input shield validation with empty inputs."""
|
||||
# Test empty shield_ids
|
||||
result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
|
||||
assert result is None
|
||||
|
||||
# Test empty messages
|
||||
result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
|
||||
assert result is None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseShieldSpec
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
extract_shield_ids,
|
||||
extract_text_content,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
"""Create mock APIs for testing."""
|
||||
return {
|
||||
"inference_api": AsyncMock(),
|
||||
"tool_groups_api": AsyncMock(),
|
||||
"tool_runtime_api": AsyncMock(),
|
||||
"responses_store": AsyncMock(),
|
||||
"vector_io_api": AsyncMock(),
|
||||
"conversations_api": AsyncMock(),
|
||||
"safety_api": AsyncMock(),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def responses_impl(mock_apis):
|
||||
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
|
||||
return OpenAIResponsesImpl(**mock_apis)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Shield ID Extraction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_extract_shield_ids_from_strings(responses_impl):
|
||||
"""Test extraction from simple string shield IDs."""
|
||||
shields = ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
result = extract_shield_ids(shields)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_shield_ids_from_objects(responses_impl):
|
||||
"""Test extraction from ResponseShieldSpec objects."""
|
||||
shields = [
|
||||
ResponseShieldSpec(type="llama-guard"),
|
||||
ResponseShieldSpec(type="content-filter"),
|
||||
]
|
||||
result = extract_shield_ids(shields)
|
||||
assert result == ["llama-guard", "content-filter"]
|
||||
|
||||
|
||||
def test_extract_shield_ids_mixed_formats(responses_impl):
|
||||
"""Test extraction from mixed string and object formats."""
|
||||
shields = [
|
||||
"llama-guard",
|
||||
ResponseShieldSpec(type="content-filter"),
|
||||
"nsfw-detector",
|
||||
]
|
||||
result = extract_shield_ids(shields)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_shield_ids_none_input(responses_impl):
|
||||
"""Test extraction with None input."""
|
||||
result = extract_shield_ids(None)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_shield_ids_empty_list(responses_impl):
|
||||
"""Test extraction with empty list."""
|
||||
result = extract_shield_ids([])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_shield_ids_unknown_format(responses_impl):
|
||||
"""Test extraction with unknown shield format raises ValueError."""
|
||||
# Create an object that's neither string nor ResponseShieldSpec
|
||||
unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec
|
||||
shields = ["valid-shield", unknown_object, "another-shield"]
|
||||
with pytest.raises(ValueError, match="Unsupported shield type"):
|
||||
extract_shield_ids(shields)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Text Content Extraction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_extract_text_content_string(responses_impl):
|
||||
"""Test extraction from simple string content."""
|
||||
content = "Hello world"
|
||||
result = extract_text_content(content)
|
||||
assert result == "Hello world"
|
||||
|
||||
|
||||
def test_extract_text_content_list_with_text(responses_impl):
|
||||
"""Test extraction from list content with text parts."""
|
||||
content = [
|
||||
MagicMock(text="Hello "),
|
||||
MagicMock(text="world"),
|
||||
]
|
||||
result = extract_text_content(content)
|
||||
assert result == "Hello world"
|
||||
|
||||
|
||||
def test_extract_text_content_list_with_refusal(responses_impl):
|
||||
"""Test extraction skips refusal parts."""
|
||||
# Create text parts
|
||||
text_part1 = MagicMock()
|
||||
text_part1.text = "Hello"
|
||||
|
||||
text_part2 = MagicMock()
|
||||
text_part2.text = "world"
|
||||
|
||||
# Create refusal part (no text attribute)
|
||||
refusal_part = MagicMock()
|
||||
refusal_part.type = "refusal"
|
||||
refusal_part.refusal = "Blocked"
|
||||
del refusal_part.text # Remove text attribute
|
||||
|
||||
content = [text_part1, refusal_part, text_part2]
|
||||
result = extract_text_content(content)
|
||||
assert result == "Hello world"
|
||||
|
||||
|
||||
def test_extract_text_content_empty_list(responses_impl):
|
||||
"""Test extraction from empty list returns None."""
|
||||
content = []
|
||||
result = extract_text_content(content)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_extract_text_content_no_text_parts(responses_impl):
|
||||
"""Test extraction with no text parts returns None."""
|
||||
# Create image part (no text attribute)
|
||||
image_part = MagicMock()
|
||||
image_part.type = "image"
|
||||
image_part.image_url = "http://example.com"
|
||||
|
||||
# Create refusal part (no text attribute)
|
||||
refusal_part = MagicMock()
|
||||
refusal_part.type = "refusal"
|
||||
refusal_part.refusal = "Blocked"
|
||||
|
||||
# Explicitly remove text attributes to simulate non-text parts
|
||||
if hasattr(image_part, "text"):
|
||||
delattr(image_part, "text")
|
||||
if hasattr(refusal_part, "text"):
|
||||
delattr(refusal_part, "text")
|
||||
|
||||
content = [image_part, refusal_part]
|
||||
result = extract_text_content(content)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_extract_text_content_none_input(responses_impl):
|
||||
"""Test extraction with None input returns None."""
|
||||
result = extract_text_content(None)
|
||||
assert result is None
|
||||
51
uv.lock
generated
51
uv.lock
generated
|
|
@ -1897,8 +1897,8 @@ requires-dist = [
|
|||
{ name = "httpx" },
|
||||
{ name = "jinja2", specifier = ">=3.1.6" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "llama-stack-client", specifier = ">=0.2.23" },
|
||||
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
|
||||
{ name = "llama-stack-client", directory = "../llama-stack-client-python" },
|
||||
{ name = "llama-stack-client", marker = "extra == 'ui'", directory = "../llama-stack-client-python" },
|
||||
{ name = "openai", specifier = ">=1.107" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
|
||||
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
||||
|
|
@ -2004,8 +2004,8 @@ unit = [
|
|||
|
||||
[[package]]
|
||||
name = "llama-stack-client"
|
||||
version = "0.2.23"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
version = "0.3.0a3"
|
||||
source = { directory = "../llama-stack-client-python" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "click" },
|
||||
|
|
@ -2023,10 +2023,47 @@ dependencies = [
|
|||
{ name = "tqdm" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9f/8f/306d5fcf2f97b3a6251219b03c194836a2ff4e0fcc8146c9970e50a72cd3/llama_stack_client-0.2.23.tar.gz", hash = "sha256:68f34e8ac8eea6a73ed9d4977d849992b2d8bd835804d770a11843431cd5bf74", size = 322288, upload-time = "2025-09-26T21:11:08.342Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/75/3eb58e092a681804013dbec7b7f549d18f55acf6fd6e6b27de7e249766d8/llama_stack_client-0.2.23-py3-none-any.whl", hash = "sha256:eee42c74eee8f218f9455e5a06d5d4be43f8a8c82a7937ef51ce367f916df847", size = 379809, upload-time = "2025-09-26T21:11:06.856Z" },
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiohttp", marker = "extra == 'aiohttp'" },
|
||||
{ name = "anyio", specifier = ">=3.5.0,<5" },
|
||||
{ name = "click" },
|
||||
{ name = "distro", specifier = ">=1.7.0,<2" },
|
||||
{ name = "fire" },
|
||||
{ name = "httpx", specifier = ">=0.23.0,<1" },
|
||||
{ name = "httpx-aiohttp", marker = "extra == 'aiohttp'", specifier = ">=0.1.8" },
|
||||
{ name = "pandas" },
|
||||
{ name = "prompt-toolkit" },
|
||||
{ name = "pyaml" },
|
||||
{ name = "pydantic", specifier = ">=1.9.0,<3" },
|
||||
{ name = "requests" },
|
||||
{ name = "rich" },
|
||||
{ name = "sniffio" },
|
||||
{ name = "termcolor" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "typing-extensions", specifier = ">=4.7,<5" },
|
||||
]
|
||||
provides-extras = ["aiohttp"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "black" },
|
||||
{ name = "dirty-equals", specifier = ">=0.6.0" },
|
||||
{ name = "importlib-metadata", specifier = ">=6.7.0" },
|
||||
{ name = "mypy" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pyright", specifier = "==1.1.399" },
|
||||
{ name = "pytest", specifier = ">=7.1.1" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-xdist", specifier = ">=3.6.1" },
|
||||
{ name = "respx" },
|
||||
{ name = "rich", specifier = ">=13.7.1" },
|
||||
{ name = "ruff" },
|
||||
{ name = "time-machine" },
|
||||
]
|
||||
pydantic-v1 = [{ name = "pydantic", specifier = ">=1.9.0,<2" }]
|
||||
pydantic-v2 = [{ name = "pydantic", specifier = ">=2,<3" }]
|
||||
|
||||
[[package]]
|
||||
name = "locust"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue