Merge branch 'responses-and-safety' into responses-and-safety

slekkala1 authored 2025-10-10 14:12:53 -07:00 · committed by GitHub
commit 74cb26a021
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 91 additions and 1 deletion

View file

@@ -91,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
             tool_runtime_api=self.tool_runtime_api,
             responses_store=self.responses_store,
             vector_io_api=self.vector_io_api,
+            safety_api=self.safety_api,
             conversations_api=self.conversations_api,
             safety_api=self.safety_api,
         )
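Note what this hunk actually produces: safety_api=self.safety_api is inserted ahead of conversations_api while the pre-existing safety_api line below it is kept, so the constructor call ends up passing the same keyword argument twice, which Python rejects at compile time (SyntaxError: keyword argument repeated). The same doubled pattern recurs in the OpenAIResponsesImpl signature, its attribute assignments, and the test fixtures below, consistent with this being a self-merge ('responses-and-safety' into 'responses-and-safety') that re-applied changes the branch already contained.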

View file

@@ -34,6 +34,7 @@ from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.conversations.conversations import ConversationItem
 from llama_stack.apis.inference import (
     Inference,
+    Message,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
 )
@@ -46,10 +47,12 @@ from llama_stack.providers.utils.responses.responses_store import (
     _OpenAIResponseObjectWithInputAndMessages,
 )
 
+from ..safety import SafetyException
 from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
 from .types import ChatCompletionContext, ToolContext
 from .utils import (
+    convert_openai_to_inference_messages,
     convert_response_input_to_chat_messages,
     convert_response_text_to_chat_response_format,
     extract_shield_ids,
@@ -71,6 +74,7 @@ class OpenAIResponsesImpl:
         tool_runtime_api: ToolRuntime,
         responses_store: ResponsesStore,
         vector_io_api: VectorIO,  # VectorIO
+        safety_api: Safety,
         conversations_api: Conversations,
         safety_api: Safety,
     ):
@@ -79,6 +83,7 @@ class OpenAIResponsesImpl:
         self.tool_runtime_api = tool_runtime_api
         self.responses_store = responses_store
         self.vector_io_api = vector_io_api
+        self.safety_api = safety_api
         self.conversations_api = conversations_api
         self.safety_api = safety_api
         self.tool_executor = ToolExecutor(
@@ -339,6 +344,21 @@ class OpenAIResponsesImpl:
             )
             await self._prepend_instructions(messages, instructions)
 
+            # Input safety validation hook - validates messages before streaming orchestrator starts
+            if shield_ids:
+                input_messages = convert_openai_to_inference_messages(messages)
+                input_refusal = await self._check_input_safety(input_messages, shield_ids)
+                if input_refusal:
+                    # Return refusal response immediately
+                    response_id = f"resp-{uuid.uuid4()}"
+                    created_at = int(time.time())
+                    async for refusal_event in self._create_refusal_response_events(
+                        input_refusal, response_id, created_at, model
+                    ):
+                        yield refusal_event
+                    return
+
             # Structured outputs
             response_format = await convert_response_text_to_chat_response_format(text)
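The _check_input_safety helper this hook calls is not part of the diff. Reconstructing it from the unit tests added at the end of this commit, a minimal sketch could look like the following; the free-function form, the exact return type, and the shape of the run_shield response are assumptions:

# Hypothetical sketch of the input-safety check, inferred from the tests below;
# not the actual implementation from this commit.
from llama_stack.apis.agents.openai_responses import OpenAIResponseContentPartRefusal
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety


async def check_input_safety(
    safety_api: Safety, messages: list[Message], shield_ids: list[str]
) -> OpenAIResponseContentPartRefusal | None:
    # The tests pin this short-circuit: no messages or no shields means no check.
    if not messages or not shield_ids:
        return None
    for shield_id in shield_ids:
        # The tests assert this exact call shape: run_shield(shield_id=..., messages=..., params={}).
        response = await safety_api.run_shield(shield_id=shield_id, messages=messages, params={})
        if response.violation:
            # A violation becomes a refusal content part carrying the violation's user_message.
            return OpenAIResponseContentPartRefusal(refusal=response.violation.user_message)
    return None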

View file

@@ -467,6 +467,15 @@ class StreamingResponseOrchestrator:
             for chunk_choice in chunk.choices:
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
+                    # Check output stream safety before yielding content
+                    violation_message = await self._check_output_stream_safety(chunk_choice.delta.content)
+                    if violation_message:
+                        # Stop streaming and send refusal response
+                        yield await self._create_refusal_response(violation_message)
+                        self.violation_detected = True
+                        # Return immediately - no further processing needed
+                        return
+
                     # Emit content_part.added event for first text chunk
                     if not content_part_emitted:
                         content_part_emitted = True
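_check_output_stream_safety and _create_refusal_response are likewise not shown in this diff. One plausible shape for the former, assuming it buffers the streamed text and re-runs the configured shields over the accumulated output (the buffering policy, the attribute names, and the str-or-None return are all assumptions inferred from the call site above):

# Hypothetical sketch only; the orchestrator's real helper may differ.
from llama_stack.apis.inference import UserMessage


class OutputSafetySketch:
    def __init__(self, safety_api, shield_ids: list[str]) -> None:
        self.safety_api = safety_api
        self.shield_ids = shield_ids
        self._accumulated_output = ""

    async def _check_output_stream_safety(self, delta_text: str) -> str | None:
        # Re-check the full text seen so far, since a violation can span chunk boundaries.
        self._accumulated_output += delta_text
        for shield_id in self.shield_ids:
            response = await self.safety_api.run_shield(
                shield_id=shield_id,
                # Wrapping model output in a message for run_shield is an assumption.
                messages=[UserMessage(content=self._accumulated_output)],
                params={},
            )
            if response.violation:
                return response.violation.user_message
        return None

Checking the accumulated text rather than each delta in isolation matters here: shields that classify whole utterances can miss harmful content that is split across streaming chunks.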

View file

@@ -47,6 +47,9 @@ from llama_stack.apis.inference import (
     OpenAISystemMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
+    StopReason,
+    SystemMessage,
+    UserMessage,
 )
 
 from llama_stack.apis.safety import Safety
@@ -244,7 +247,7 @@ async def convert_response_text_to_chat_response_format(
     raise ValueError(f"Unsupported text format: {text.format}")
 
 
-def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
+async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
     """Get the appropriate OpenAI message parameter type for a given role."""
     role_to_type = {
         "user": OpenAIUserMessageParam,

View file

@@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
     ListOpenAIResponseInputItem,
+    OpenAIResponseContentPartRefusal,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolMCP,
@@ -38,7 +39,9 @@ from llama_stack.apis.inference import (
     OpenAIResponseFormatJSONObject,
     OpenAIResponseFormatJSONSchema,
     OpenAIUserMessageParam,
+    UserMessage,
 )
+from llama_stack.apis.safety import SafetyViolation, ViolationLevel
 from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.core.access_control.access_control import default_policy
 from llama_stack.core.datatypes import ResponsesStoreConfig
@@ -84,6 +87,9 @@ def mock_vector_io_api():
 
 
 @pytest.fixture
+def mock_safety_api():
+    safety_api = AsyncMock()
+    return safety_api
 def mock_conversations_api():
     """Mock conversations API for testing."""
     mock_api = AsyncMock()
@@ -103,6 +109,7 @@ def openai_responses_impl(
     mock_tool_runtime_api,
     mock_responses_store,
     mock_vector_io_api,
+    mock_safety_api,
     mock_conversations_api,
     mock_safety_api,
 ):
@@ -112,6 +119,7 @@
         tool_runtime_api=mock_tool_runtime_api,
         responses_store=mock_responses_store,
         vector_io_api=mock_vector_io_api,
+        safety_api=mock_safety_api,
         conversations_api=mock_conversations_api,
         safety_api=mock_safety_api,
     )
@@ -1090,3 +1098,52 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
             model=model,
             text=OpenAIResponseText(format={"type": "invalid"}),
         )
+
+
+async def test_check_input_safety_no_violation(openai_responses_impl):
+    """Test input shield validation with no violations."""
+    messages = [UserMessage(content="Hello world")]
+    shield_ids = ["llama-guard"]
+
+    # Mock successful shield validation (no violation)
+    mock_response = AsyncMock()
+    mock_response.violation = None
+    openai_responses_impl.safety_api.run_shield.return_value = mock_response
+
+    result = await openai_responses_impl._check_input_safety(messages, shield_ids)
+
+    assert result is None
+    openai_responses_impl.safety_api.run_shield.assert_called_once_with(
+        shield_id="llama-guard", messages=messages, params={}
+    )
+
+
+async def test_check_input_safety_with_violation(openai_responses_impl):
+    """Test input shield validation with safety violation."""
+    messages = [UserMessage(content="Harmful content")]
+    shield_ids = ["llama-guard"]
+
+    # Mock shield violation
+    violation = SafetyViolation(
+        violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
+    )
+    mock_response = AsyncMock()
+    mock_response.violation = violation
+    openai_responses_impl.safety_api.run_shield.return_value = mock_response
+
+    result = await openai_responses_impl._check_input_safety(messages, shield_ids)
+
+    assert isinstance(result, OpenAIResponseContentPartRefusal)
+    assert result.refusal == "Content violates safety guidelines"
+    assert result.type == "refusal"
+
+
+async def test_check_input_safety_empty_inputs(openai_responses_impl):
+    """Test input shield validation with empty inputs."""
+    # Test empty shield_ids
+    result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
+    assert result is None
+
+    # Test empty messages
+    result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
+    assert result is None
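Taken together, the three tests pin down the helper's contract: run_shield is invoked once per configured shield with params={}; an empty message list or an empty shield list short-circuits to None; and a violation surfaces as an OpenAIResponseContentPartRefusal with type "refusal" and the violation's user_message as its text, rather than raising the SafetyException imported above, so streaming clients receive a well-formed refusal response instead of an aborted stream.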