feat: Add responses and safety impl with extra body

Swapna Lekkala 2025-10-10 07:12:51 -07:00
parent 6954fe2274
commit 9152efa1a9
18 changed files with 833 additions and 164 deletions

View file

@@ -92,6 +92,7 @@ class MetaReferenceAgentsImpl(Agents):
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
conversations_api=self.conversations_api,
safety_api=self.safety_api,
)
async def create_agent(
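The shields themselves reach the provider through extra_body on the Responses request rather than a first-class API field. A minimal client-side sketch, assuming an OpenAI-compatible client pointed at a Llama Stack server; the base URL, model id, and shield id below are placeholders, not values taken from this commit:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # assumed endpoint; adjust to your deployment

# "shields" is not part of the standard OpenAI schema, so it is passed via extra_body.
response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",   # placeholder model id
    input="Tell me how to pick a lock.",
    extra_body={"shields": ["llama-guard"]},     # placeholder shield id; list of ids or ResponseShieldSpec objects
)
print(response.output)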

View file

@@ -15,12 +15,15 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseContentPartRefusal,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
@@ -31,9 +34,11 @@ from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
Message,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
@@ -42,12 +47,15 @@ from llama_stack.providers.utils.responses.responses_store import (
_OpenAIResponseObjectWithInputAndMessages,
)
from ..safety import SafetyException
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_shield_ids,
run_multiple_shields,
)
logger = get_logger(name=__name__, category="openai_responses")
@@ -67,6 +75,7 @@ class OpenAIResponsesImpl:
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
conversations_api: Conversations,
safety_api: Safety,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
@@ -74,6 +83,7 @@ class OpenAIResponsesImpl:
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.conversations_api = conversations_api
self.safety_api = safety_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@@ -225,9 +235,7 @@ class OpenAIResponsesImpl:
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
shield_ids = extract_shield_ids(shields) if shields else []
if conversation is not None and previous_response_id is not None:
raise ValueError(
@@ -255,6 +263,7 @@ class OpenAIResponsesImpl:
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
shield_ids=shield_ids,
)
if stream:
@@ -288,6 +297,42 @@ class OpenAIResponsesImpl:
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _check_input_safety(
self, messages: list[Message], shield_ids: list[str]
) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
try:
await run_multiple_shields(self.safety_api, messages, shield_ids)
except SafetyException as e:
logger.info(f"Input shield violation: {e.violation.user_message}")
return OpenAIResponseContentPartRefusal(
refusal=e.violation.user_message or "Content blocked by safety shields"
)
async def _create_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create and yield refusal response events following the established streaming pattern."""
# Create initial response and yield created event
initial_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="in_progress",
output=[],
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Create completed refusal response using OpenAIResponseContentPartRefusal
refusal_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
@@ -301,6 +346,7 @@ class OpenAIResponsesImpl:
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
shield_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
@@ -333,8 +379,11 @@ class OpenAIResponsesImpl:
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
shield_ids=shield_ids,
)
# Output safety validation hook - delegated to streaming orchestrator for real-time validation
# Stream the response
final_response = None
failed_response = None
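When an input shield fires, the new code short-circuits the stream: the caller receives a response.created event with status in_progress, followed immediately by a response.completed event whose only output message carries a refusal content part instead of model text. A hedged consumer-side sketch (event and field names follow the stream types used above; obtaining the stream object is assumed):

async def print_result(stream):
    # stream: AsyncIterator[OpenAIResponseObjectStream] returned by _create_streaming_response(...)
    async for event in stream:
        if event.type != "response.completed":
            continue
        for item in event.response.output:
            for part in getattr(item, "content", None) or []:
                if getattr(part, "type", None) == "refusal":
                    print("Blocked by shield:", part.refusal)
                    return
        print("Completed normally")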

View file

@@ -13,10 +13,12 @@ from llama_stack.apis.agents.openai_responses import (
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
@@ -45,6 +47,7 @@ from llama_stack.apis.agents.openai_responses import (
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
CompletionMessage,
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
@@ -52,12 +55,18 @@ from llama_stack.apis.inference import (
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
StopReason,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from ..safety import SafetyException
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
from .utils import (
convert_chat_choice_to_response_message,
is_function_tool_call,
run_multiple_shields,
)
logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -93,6 +102,8 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
safety_api,
shield_ids: list[str] | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -101,6 +112,8 @@ class StreamingResponseOrchestrator:
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.safety_api = safety_api
self.shield_ids = shield_ids or []
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
@@ -110,6 +123,61 @@ class StreamingResponseOrchestrator:
self.citation_files: dict[str, str] = {}
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
async def _check_input_safety(self, messages: list[OpenAIMessageParam]) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
except SafetyException as e:
logger.info(f"Input shield violation: {e.violation.user_message}")
return OpenAIResponseContentPartRefusal(
refusal=e.violation.user_message or "Content blocked by safety shields"
)
async def _create_input_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create refusal response events for input safety violations."""
# Create the refusal content part explicitly with the correct structure
refusal_part = OpenAIResponseContentPartRefusal(refusal=refusal_content.refusal, type="refusal")
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_part], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
"""Check accumulated streaming text content against shields. Returns violation message if blocked."""
if not self.shield_ids or not accumulated_text:
return None
messages = [CompletionMessage(content=accumulated_text, stop_reason=StopReason.end_of_turn)]
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
except SafetyException as e:
logger.info(f"Output shield violation: {e.violation.user_message}")
return e.violation.user_message or "Generated content blocked by safety shields"
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
# Create a completed refusal response
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
@@ -154,6 +222,15 @@ class StreamingResponseOrchestrator:
sequence_number=self.sequence_number,
)
# Input safety validation - check messages before processing
if self.shield_ids:
input_refusal = await self._check_input_safety(self.ctx.messages)
if input_refusal:
# Return refusal response immediately
async for refusal_event in self._create_input_refusal_response_events(input_refusal):
yield refusal_event
return
async for stream_event in self._process_tools(output_messages):
yield stream_event
@@ -187,6 +264,10 @@ class StreamingResponseOrchestrator:
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
# If violation detected, skip the rest of processing since we already sent refusal
if self.violation_detected:
return
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data
@@ -475,6 +556,15 @@ class StreamingResponseOrchestrator:
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Safety check after processing all chunks
if chat_response_content:
accumulated_text = "".join(chat_response_content)
violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
if violation_message:
yield await self._create_refusal_response(violation_message)
self.violation_detected = True
return
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
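On the output side the orchestrator buffers the streamed text and only after the model's chunks are exhausted runs the joined text through the shields; on a violation it emits a completed refusal response, sets violation_detected, and skips the rest of processing. A condensed, standalone sketch of that buffer-then-validate pattern (the dict events and the ShieldViolation class are stand-ins, not the provider's actual types):

from collections.abc import AsyncIterator, Awaitable, Callable

class ShieldViolation(Exception):
    # Stand-in for SafetyException in this sketch.
    def __init__(self, user_message: str):
        super().__init__(user_message)
        self.user_message = user_message

async def stream_with_output_check(
    text_chunks: AsyncIterator[str],
    check: Callable[[str], Awaitable[None]],  # raises ShieldViolation on unsafe text
) -> AsyncIterator[dict]:
    buffered: list[str] = []
    async for delta in text_chunks:
        buffered.append(delta)
        yield {"type": "output_text.delta", "delta": delta}
    try:
        await check("".join(buffered))
    except ShieldViolation as e:
        # Terminal event carries a refusal instead of the assembled text.
        yield {"type": "completed", "refusal": e.user_message}
        return
    yield {"type": "completed", "text": "".join(buffered)}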

View file

@@ -7,6 +7,7 @@
import re
import uuid
from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
@@ -26,6 +27,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
)
from llama_stack.apis.inference import (
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
@@ -45,6 +47,7 @@ from llama_stack.apis.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import Safety
async def convert_chat_choice_to_response_message(
@@ -171,7 +174,7 @@ async def convert_response_input_to_chat_messages(
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
message_type = get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
@@ -240,7 +243,8 @@ async def convert_response_text_to_chat_response_format(
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str):
def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
"""Get the appropriate OpenAI message parameter type for a given role."""
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
@@ -307,3 +311,52 @@ def is_function_tool_call(
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
"""Run multiple shields against messages and raise SafetyException for violations."""
if not shield_ids or not messages:
return
for shield_id in shield_ids:
response = await safety_api.run_shield(
shield_id=shield_id,
messages=messages,
params={},
)
if response.violation and response.violation.violation_level.name == "ERROR":
from ..safety import SafetyException
raise SafetyException(response.violation)
def extract_shield_ids(shields: list | None) -> list[str]:
"""Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects."""
if not shields:
return []
shield_ids = []
for shield in shields:
if isinstance(shield, str):
shield_ids.append(shield)
elif isinstance(shield, ResponseShieldSpec):
shield_ids.append(shield.type)
else:
raise ValueError(f"Unsupported shield type: {type(shield)}")
return shield_ids
def extract_text_content(content: str | list | None) -> str | None:
"""Extract text content from OpenAI message content (string or complex structure)."""
if isinstance(content, str):
return content
elif isinstance(content, list):
# Handle complex content - extract text parts only
text_parts = []
for part in content:
if hasattr(part, "text"):
text_parts.append(part.text)
elif hasattr(part, "type") and part.type == "refusal":
# Skip refusal parts - don't validate them again
continue
return " ".join(text_parts) if text_parts else None
return None
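The helpers added at the bottom of utils.py normalize caller input for the safety path: extract_shield_ids accepts either bare shield id strings or ResponseShieldSpec objects, and extract_text_content flattens OpenAI-style message content to plain text while skipping refusal parts. Roughly, with placeholder ids (a sketch of expected behaviour, not captured test output):

extract_shield_ids(["llama-guard", ResponseShieldSpec(type="content-safety")])
# -> ["llama-guard", "content-safety"]
extract_shield_ids(None)
# -> []

extract_text_content("hello there")
# -> "hello there"
extract_text_content([
    OpenAIResponseInputMessageContentText(text="part one"),
    OpenAIResponseInputMessageContentText(text="part two"),
])
# -> "part one part two"
extract_text_content([OpenAIResponseContentPartRefusal(refusal="blocked")])
# -> None   (refusal parts are skipped, nothing else remains)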

View file

@@ -247,12 +247,17 @@ class LlamaGuardShield:
self.safety_categories = safety_categories
def check_unsafe_response(self, response: str) -> str | None:
# Check for "unsafe\n<code>" format
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
# Check for direct category code format (e.g., "S1", "S2", etc.)
if re.match(r"^S\d+$", response):
return response
return None
def get_safety_categories(self) -> list[str]:
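For reference, the updated parser accepts both the classic two-line Llama Guard completion and a bare category code. Expected behaviour (a sketch; shield is an already constructed LlamaGuardShield):

shield.check_unsafe_response("safe")            # -> None
shield.check_unsafe_response("unsafe\nS1")      # -> "S1"
shield.check_unsafe_response("unsafe\nS1,S13")  # -> "S1,S13"
shield.check_unsafe_response("S9")              # -> "S9"   (new: bare category code)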