feat: Add responses and safety impl extra_body (#3781)

# What does this PR do?

Closed the previous PR due to merge conflicts with multiple other PRs.
This PR addresses all comments from
https://github.com/llamastack/llama-stack/pull/3768 (sorry for carrying
them over to this one).


## Test Plan
Added unit tests and integration tests.
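
A minimal sketch (not taken from this PR's test files) of the kind of unit coverage for the new `extract_guardrail_ids` helper; the helper's module path below is an assumption, since the diff does not show file names:

```python
# Hypothetical unit-test sketch; the module path of extract_guardrail_ids is assumed.
import pytest

from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (  # assumed path
    extract_guardrail_ids,
)


def test_extract_guardrail_ids_accepts_strings_and_specs():
    # String IDs pass through unchanged; spec objects contribute their `type` field.
    guardrails = ["llama-guard", ResponseGuardrailSpec(type="nsfw-filter")]
    assert extract_guardrail_ids(guardrails) == ["llama-guard", "nsfw-filter"]


def test_extract_guardrail_ids_rejects_unknown_shapes():
    # Anything other than str or ResponseGuardrailSpec raises ValueError.
    with pytest.raises(ValueError):
        extract_guardrail_ids([42])
```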
Commit 99141c29b1 (parent 8e7e0ddfec) by slekkala1, 2025-10-15 15:01:37 -07:00, committed by GitHub
244 changed files with 36829 additions and 235 deletions

View file

@@ -43,17 +43,17 @@ from .openai_responses import (
@json_schema_type
class ResponseShieldSpec(BaseModel):
"""Specification for a shield to apply during response generation.
class ResponseGuardrailSpec(BaseModel):
"""Specification for a guardrail to apply during response generation.
:param type: The type/identifier of the shield.
:param type: The type/identifier of the guardrail.
"""
type: str
# TODO: more fields to be added for shield configuration
# TODO: more fields to be added for guardrail configuration
ResponseShield = str | ResponseShieldSpec
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
@@ -820,10 +820,10 @@ class Agents(Protocol):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
shields: Annotated[
list[ResponseShield] | None,
guardrails: Annotated[
list[ResponseGuardrail] | None,
ExtraBodyField(
"List of shields to apply during response generation. Shields provide safety and content moderation."
"List of guardrails to apply during response generation. Guardrails provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
@@ -834,7 +834,7 @@ class Agents(Protocol):
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param conversation: (Optional) The ID of a conversation to add the response to. Must begin with 'conv_'. Input and output messages will be automatically added to the conversation.
:param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
:returns: An OpenAIResponseObject.
"""
...
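
Because `guardrails` is declared as an `ExtraBodyField`, it is not part of the typed OpenAI request schema; clients pass it through the request body's extra fields. A minimal sketch of such a call using the OpenAI Python client, where the base URL, model name, and guardrail identifier are assumptions rather than values from this diff:

```python
# Illustrative client call; base_url, api_key, model, and the guardrail ID are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="Tell me something about planets.",
    # extra_body fields are forwarded alongside the standard payload; `guardrails`
    # accepts shield IDs (strings) or {"type": ...} guardrail specs.
    extra_body={"guardrails": ["llama-guard"]},
)
print(response.output_text)
```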

View file

@@ -131,8 +131,20 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
OpenAIResponseOutputMessageContent = Annotated[
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
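
With the refusal part added to the discriminated union, assistant output messages can carry refusal content directly. A self-contained sketch of the same pattern with simplified stand-in models (not the actual llama-stack classes):

```python
# Simplified stand-ins mirroring the union change; not the real llama-stack models.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class OutputText(BaseModel):
    type: Literal["output_text"] = "output_text"
    text: str


class Refusal(BaseModel):
    type: Literal["refusal"] = "refusal"
    refusal: str


# Discriminated union keyed on `type`, as in OpenAIResponseOutputMessageContent.
OutputContent = Annotated[OutputText | Refusal, Field(discriminator="type")]

part = TypeAdapter(OutputContent).validate_python(
    {"type": "refusal", "refusal": "I can't help with that."}
)
assert isinstance(part, Refusal)
```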
@@ -878,18 +890,6 @@ class OpenAIResponseContentPartOutputText(BaseModel):
logprobs: list[dict[str, Any]] | None = None
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
"""Refusal content within a streamed response part.
:param type: Content part type identifier, always "refusal"
:param refusal: Refusal text supplied by the model
"""
type: Literal["refusal"] = "refusal"
refusal: str
@json_schema_type
class OpenAIResponseContentPartReasoningText(BaseModel):
"""Reasoning text emitted as part of a streamed response.

View file

@@ -28,6 +28,7 @@ from llama_stack.apis.agents import (
Session,
Turn,
)
from llama_stack.apis.agents.agents import ResponseGuardrail
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
@@ -91,6 +92,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
safety_api=self.safety_api,
conversations_api=self.conversations_api,
)
@@ -337,7 +339,7 @@ class MetaReferenceAgentsImpl(Agents):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input,
@@ -352,7 +354,7 @@ class MetaReferenceAgentsImpl(Agents):
tools,
include,
max_infer_iters,
shields,
guardrails,
)
async def list_openai_responses(

View file

@@ -11,6 +11,7 @@ from collections.abc import AsyncIterator
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
@@ -34,6 +35,7 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
@@ -48,6 +50,7 @@ from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_guardrail_ids,
)
logger = get_logger(name=__name__, category="openai_responses")
@@ -66,6 +69,7 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
safety_api: Safety,
conversations_api: Conversations,
):
self.inference_api = inference_api
@@ -73,6 +77,7 @@ class OpenAIResponsesImpl:
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.safety_api = safety_api
self.conversations_api = conversations_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
@@ -244,14 +249,12 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
guardrails: list[ResponseGuardrailSpec] | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
if conversation is not None:
if previous_response_id is not None:
@@ -273,6 +276,7 @@
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
guardrail_ids=guardrail_ids,
)
if stream:
@@ -318,6 +322,7 @@
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
@@ -352,6 +357,8 @@
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
guardrail_ids=guardrail_ids,
)
# Stream the response

View file

@@ -66,10 +66,15 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
from .utils import (
convert_chat_choice_to_response_message,
is_function_tool_call,
run_guardrails,
)
logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -105,6 +110,8 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
safety_api,
guardrail_ids: list[str] | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@@ -113,6 +120,8 @@
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.safety_api = safety_api
self.guardrail_ids = guardrail_ids or []
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
@@ -122,6 +131,23 @@
self.citation_files: dict[str, str] = {}
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
# Track if we've sent a refusal response
self.violation_detected = False
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
# Create a completed refusal response
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
@@ -166,6 +192,15 @@ class StreamingResponseOrchestrator:
sequence_number=self.sequence_number,
)
# Input safety validation - check messages before processing
if self.guardrail_ids:
combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
if input_violation_message:
logger.info(f"Input guardrail violation: {input_violation_message}")
yield await self._create_refusal_response(input_violation_message)
return
async for stream_event in self._process_tools(output_messages):
yield stream_event
@@ -201,6 +236,11 @@
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
# If violation detected, skip the rest of processing since we already sent refusal
if self.violation_detected:
return
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data
@@ -525,6 +565,9 @@
# Accumulate usage from chunks (typically in final chunk with stream_options)
self._accumulate_chunk_usage(chunk)
# Track deltas for this specific chunk for guardrail validation
chunk_events: list[OpenAIResponseObjectStream] = []
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
@@ -560,13 +603,19 @@
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
text_delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=content_index,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Buffer text delta events for guardrail check
if self.guardrail_ids:
chunk_events.append(text_delta_event)
else:
yield text_delta_event
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
@@ -582,7 +631,11 @@
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
# Buffer reasoning events for guardrail check
if self.guardrail_ids:
chunk_events.append(event)
else:
yield event
reasoning_part_emitted = True
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
@@ -664,6 +717,22 @@
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Output Safety Validation for this chunk
if self.guardrail_ids:
# Check guardrails on accumulated text so far
accumulated_text = "".join(chat_response_content)
violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
if violation_message:
logger.info(f"Output guardrail violation: {violation_message}")
chunk_events.clear()
yield await self._create_refusal_response(violation_message)
self.violation_detected = True
return
else:
# No violation detected, emit all content events for this chunk
for event in chunk_events:
yield event
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
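
In the streaming path, guardrails run twice: once over the combined input before inference starts, and then over the accumulated output on every streamed chunk, with that chunk's delta events buffered until the check passes. A stripped-down sketch of the buffer-then-flush control flow, independent of the llama-stack types (all names here are illustrative):

```python
# Illustrative reduction of the buffer-then-check streaming pattern; not the orchestrator itself.
import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable


async def guarded_stream(
    chunks: AsyncIterator[str],
    check: Callable[[str], Awaitable[str | None]],  # returns a violation message, or None if clean
) -> AsyncIterator[str]:
    accumulated = ""
    async for chunk in chunks:
        buffered = [chunk]                    # buffer this chunk's delta events
        accumulated += chunk
        violation = await check(accumulated)  # guardrail over everything produced so far
        if violation:
            yield f"[refusal] {violation}"    # drop buffered deltas, emit a single refusal
            return
        for event in buffered:                # no violation: flush buffered deltas in order
            yield event


async def _demo() -> None:
    async def chunks() -> AsyncIterator[str]:
        for piece in ["hello ", "world"]:
            yield piece

    async def check(text: str) -> str | None:
        return None  # always clean in this demo

    async for event in guarded_stream(chunks(), check):
        print(event, end="")


asyncio.run(_demo())
```

Buffering per chunk bounds the overhead to one guardrail call per inference chunk while ensuring no flagged text is ever emitted to the client.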

View file

@@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import re
import uuid
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
@@ -45,6 +47,7 @@ from llama_stack.apis.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import Safety
async def convert_chat_choice_to_response_message(
@@ -240,7 +243,8 @@ async def convert_response_text_to_chat_response_format(
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str):
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
"""Get the appropriate OpenAI message parameter type for a given role."""
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
@@ -307,3 +311,55 @@ def is_function_tool_call(
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run guardrails against messages and return violation message if blocked."""
if not messages:
return None
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()
for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
if matching_shields:
model_id = matching_shields[0].provider_resource_id
model_ids.append(model_id)
else:
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
responses = await asyncio.gather(*guardrail_tasks)
for response in responses:
for result in response.results:
if result.flagged:
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories:
message += f" (flagged for: {', '.join(flagged_categories)})"
if violation_type:
message += f" (violation type: {', '.join(violation_type)})"
return message
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
if not guardrails:
return []
guardrail_ids = []
for guardrail in guardrails:
if isinstance(guardrail, str):
guardrail_ids.append(guardrail)
elif isinstance(guardrail, ResponseGuardrailSpec):
guardrail_ids.append(guardrail.type)
else:
raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
return guardrail_ids
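
For reference, a hypothetical driver showing what `run_guardrails` returns when a stubbed Safety API flags content; the helper's import path and the stub shapes are assumptions:

```python
# Hypothetical driver using a stubbed Safety API; the import path of run_guardrails is assumed.
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock

from llama_stack.providers.inline.agents.meta_reference.responses.utils import (  # assumed path
    run_guardrails,
)


async def main() -> None:
    shield = SimpleNamespace(identifier="llama-guard", provider_resource_id="llama-guard-3-8b")
    moderation_result = SimpleNamespace(
        flagged=True,
        user_message=None,
        categories={"violence": True, "hate": False},
        metadata={"violation_type": ["S1"]},
    )
    safety_api = SimpleNamespace(
        routing_table=SimpleNamespace(
            list_shields=AsyncMock(return_value=SimpleNamespace(data=[shield]))
        ),
        run_moderation=AsyncMock(return_value=SimpleNamespace(results=[moderation_result])),
    )

    message = await run_guardrails(safety_api, "some user text", ["llama-guard"])
    # Expected: "Content blocked by safety guardrails (flagged for: violence) (violation type: S1)"
    print(message)


asyncio.run(main())
```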