Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-18 15:27:16 +00:00
feat: Add responses and safety impl extra_body (#3781)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 1s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 6s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 9s
Test External API and Providers / test-external (venv) (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 7s
Unit Tests / unit-tests (3.12) (push) Failing after 9s
API Conformance Tests / check-schema-compatibility (push) Successful in 19s
UI Tests / ui-tests (22) (push) Successful in 37s
Pre-commit / pre-commit (push) Successful in 1m33s
# What does this PR do?

Closed the previous PR due to merge conflicts with multiple PRs. Addressed all comments from https://github.com/llamastack/llama-stack/pull/3768 (sorry for carrying them over to this one).

## Test Plan

Added unit tests and integration tests.
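For context, below is a minimal client-side sketch of how the new `guardrails` field could be exercised against a Llama Stack server. The base URL and its path, the API key, the model name, and the shield identifier `llama-guard` are assumptions for illustration, not part of this PR; the only grounded piece is that non-standard fields such as `guardrails` travel via `extra_body`, as the commit title and the provider comments indicate.

```python
# Hypothetical client usage of the new `guardrails` extra_body field.
# Assumes a Llama Stack server exposing an OpenAI-compatible endpoint
# and a shield registered under the identifier "llama-guard".
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    input="Tell me a fun fact about llamas.",
    # Fields the OpenAI SDK does not know about ride along in extra_body;
    # the meta-reference provider resolves these into guardrail IDs.
    extra_body={"guardrails": ["llama-guard"]},
)
print(response.output_text)
```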
This commit is contained in:
parent 8e7e0ddfec
commit 99141c29b1
244 changed files with 36829 additions and 235 deletions
@@ -28,6 +28,7 @@ from llama_stack.apis.agents import (
    Session,
    Turn,
)
from llama_stack.apis.agents.agents import ResponseGuardrail
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations

@@ -91,6 +92,7 @@ class MetaReferenceAgentsImpl(Agents):
            tool_runtime_api=self.tool_runtime_api,
            responses_store=self.responses_store,
            vector_io_api=self.vector_io_api,
            safety_api=self.safety_api,
            conversations_api=self.conversations_api,
        )

@@ -337,7 +339,7 @@ class MetaReferenceAgentsImpl(Agents):
        tools: list[OpenAIResponseInputTool] | None = None,
        include: list[str] | None = None,
        max_infer_iters: int | None = 10,
        shields: list | None = None,
        guardrails: list[ResponseGuardrail] | None = None,
    ) -> OpenAIResponseObject:
        return await self.openai_responses_impl.create_openai_response(
            input,

@@ -352,7 +354,7 @@ class MetaReferenceAgentsImpl(Agents):
            tools,
            include,
            max_infer_iters,
            shields,
            guardrails,
        )

    async def list_openai_responses(
@@ -11,6 +11,7 @@ from collections.abc import AsyncIterator
from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.agents import Order
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
    ListOpenAIResponseInputItem,
    ListOpenAIResponseObject,

@@ -34,6 +35,7 @@ from llama_stack.apis.inference import (
    OpenAIMessageParam,
    OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger

@@ -48,6 +50,7 @@ from .types import ChatCompletionContext, ToolContext
from .utils import (
    convert_response_input_to_chat_messages,
    convert_response_text_to_chat_response_format,
    extract_guardrail_ids,
)

logger = get_logger(name=__name__, category="openai_responses")

@@ -66,6 +69,7 @@ class OpenAIResponsesImpl:
        tool_runtime_api: ToolRuntime,
        responses_store: ResponsesStore,
        vector_io_api: VectorIO,  # VectorIO
        safety_api: Safety,
        conversations_api: Conversations,
    ):
        self.inference_api = inference_api

@@ -73,6 +77,7 @@ class OpenAIResponsesImpl:
        self.tool_runtime_api = tool_runtime_api
        self.responses_store = responses_store
        self.vector_io_api = vector_io_api
        self.safety_api = safety_api
        self.conversations_api = conversations_api
        self.tool_executor = ToolExecutor(
            tool_groups_api=tool_groups_api,

@@ -244,14 +249,12 @@ class OpenAIResponsesImpl:
        tools: list[OpenAIResponseInputTool] | None = None,
        include: list[str] | None = None,
        max_infer_iters: int | None = 10,
        shields: list | None = None,
        guardrails: list[ResponseGuardrailSpec] | None = None,
    ):
        stream = bool(stream)
        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

        # Shields parameter received via extra_body - not yet implemented
        if shields is not None:
            raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
        guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []

        if conversation is not None:
            if previous_response_id is not None:

@@ -273,6 +276,7 @@ class OpenAIResponsesImpl:
            text=text,
            tools=tools,
            max_infer_iters=max_infer_iters,
            guardrail_ids=guardrail_ids,
        )

        if stream:

@@ -318,6 +322,7 @@ class OpenAIResponsesImpl:
        text: OpenAIResponseText | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
        max_infer_iters: int | None = 10,
        guardrail_ids: list[str] | None = None,
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # Input preprocessing
        all_input, messages, tool_context = await self._process_input_with_previous_response(

@@ -352,6 +357,8 @@ class OpenAIResponsesImpl:
            text=text,
            max_infer_iters=max_infer_iters,
            tool_executor=self.tool_executor,
            safety_api=self.safety_api,
            guardrail_ids=guardrail_ids,
        )

        # Stream the response
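The hunks above encode a simple contract: a legacy `shields` value arriving via extra_body is rejected outright, while `guardrails` entries are normalized to plain guardrail IDs and threaded into the streaming orchestrator. Below is a small, self-contained sketch of that contract; the `GuardrailSpec` dataclass is a hypothetical stand-in for `ResponseGuardrailSpec`, not the real type.

```python
# Illustrative sketch of the accept/reject split above (not the provider code itself).
from dataclasses import dataclass


@dataclass
class GuardrailSpec:  # hypothetical stand-in for ResponseGuardrailSpec
    type: str


def normalize_request(shields: list | None, guardrails: list | None) -> list[str]:
    """Reject legacy shields, normalize guardrails to plain string IDs."""
    if shields is not None:
        # Mirrors the NotImplementedError raised by the meta-reference provider.
        raise NotImplementedError("Shields parameter is not yet implemented")
    return [g if isinstance(g, str) else g.type for g in guardrails or []]


assert normalize_request(None, ["llama-guard", GuardrailSpec(type="jailbreak")]) == ["llama-guard", "jailbreak"]
```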
@@ -66,10 +66,15 @@ from llama_stack.apis.inference import (
    OpenAIMessageParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.telemetry import tracing

from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
from .utils import (
    convert_chat_choice_to_response_message,
    is_function_tool_call,
    run_guardrails,
)

logger = get_logger(name=__name__, category="agents::meta_reference")

@@ -105,6 +110,8 @@ class StreamingResponseOrchestrator:
        text: OpenAIResponseText,
        max_infer_iters: int,
        tool_executor,  # Will be the tool execution logic from the main class
        safety_api,
        guardrail_ids: list[str] | None = None,
    ):
        self.inference_api = inference_api
        self.ctx = ctx

@@ -113,6 +120,8 @@ class StreamingResponseOrchestrator:
        self.text = text
        self.max_infer_iters = max_infer_iters
        self.tool_executor = tool_executor
        self.safety_api = safety_api
        self.guardrail_ids = guardrail_ids or []
        self.sequence_number = 0
        # Store MCP tool mapping that gets built during tool processing
        self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}

@@ -122,6 +131,23 @@ class StreamingResponseOrchestrator:
        self.citation_files: dict[str, str] = {}
        # Track accumulated usage across all inference calls
        self.accumulated_usage: OpenAIResponseUsage | None = None
        # Track if we've sent a refusal response
        self.violation_detected = False

    async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
        """Create a refusal response to replace streaming content."""
        refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)

        # Create a completed refusal response
        refusal_response = OpenAIResponseObject(
            id=self.response_id,
            created_at=self.created_at,
            model=self.ctx.model,
            status="completed",
            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
        )

        return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)

    def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
        cloned: list[OpenAIResponseOutput] = []

@@ -166,6 +192,15 @@ class StreamingResponseOrchestrator:
            sequence_number=self.sequence_number,
        )

        # Input safety validation - check messages before processing
        if self.guardrail_ids:
            combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
            input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
            if input_violation_message:
                logger.info(f"Input guardrail violation: {input_violation_message}")
                yield await self._create_refusal_response(input_violation_message)
                return

        async for stream_event in self._process_tools(output_messages):
            yield stream_event

@@ -201,6 +236,11 @@ class StreamingResponseOrchestrator:
                completion_result_data = stream_event_or_result
            else:
                yield stream_event_or_result

        # If violation detected, skip the rest of processing since we already sent refusal
        if self.violation_detected:
            return

        if not completion_result_data:
            raise ValueError("Streaming chunk processor failed to return completion data")
        last_completion_result = completion_result_data

@@ -525,6 +565,9 @@ class StreamingResponseOrchestrator:
        # Accumulate usage from chunks (typically in final chunk with stream_options)
        self._accumulate_chunk_usage(chunk)

        # Track deltas for this specific chunk for guardrail validation
        chunk_events: list[OpenAIResponseObjectStream] = []

        for chunk_choice in chunk.choices:
            # Emit incremental text content as delta events
            if chunk_choice.delta.content:

@@ -560,13 +603,19 @@ class StreamingResponseOrchestrator:
                    sequence_number=self.sequence_number,
                )
                self.sequence_number += 1
                yield OpenAIResponseObjectStreamResponseOutputTextDelta(
                text_delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
                    content_index=content_index,
                    delta=chunk_choice.delta.content,
                    item_id=message_item_id,
                    output_index=message_output_index,
                    sequence_number=self.sequence_number,
                )
                # Buffer text delta events for guardrail check
                if self.guardrail_ids:
                    chunk_events.append(text_delta_event)
                else:
                    yield text_delta_event

                # Collect content for final response
                chat_response_content.append(chunk_choice.delta.content or "")

@@ -582,7 +631,11 @@ class StreamingResponseOrchestrator:
                    message_item_id=message_item_id,
                    message_output_index=message_output_index,
                ):
                    yield event
                    # Buffer reasoning events for guardrail check
                    if self.guardrail_ids:
                        chunk_events.append(event)
                    else:
                        yield event
                reasoning_part_emitted = True
                reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)

@@ -664,6 +717,22 @@ class StreamingResponseOrchestrator:
                        response_tool_call.function.arguments or ""
                    ) + tool_call.function.arguments

        # Output Safety Validation for this chunk
        if self.guardrail_ids:
            # Check guardrails on accumulated text so far
            accumulated_text = "".join(chat_response_content)
            violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
            if violation_message:
                logger.info(f"Output guardrail violation: {violation_message}")
                chunk_events.clear()
                yield await self._create_refusal_response(violation_message)
                self.violation_detected = True
                return
            else:
                # No violation detected, emit all content events for this chunk
                for event in chunk_events:
                    yield event

        # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
        for tool_call_index in sorted(chat_response_tool_calls.keys()):
            tool_call = chat_response_tool_calls[tool_call_index]
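The orchestrator changes above follow a buffer-then-check pattern: per chunk, text and reasoning delta events are held back while guardrails run against the text accumulated so far, and then either the buffered deltas are flushed or a single refusal event is emitted and the stream stops. A simplified, self-contained sketch of that idea, with a stubbed check standing in for `run_guardrails` and plain strings standing in for the real event objects:

```python
# Simplified illustration of the buffer-then-check streaming pattern (not the provider code).
import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable


async def stream_with_guardrails(
    chunks: AsyncIterator[str],
    check: Callable[[str], Awaitable[str | None]],  # returns a violation message or None
) -> AsyncIterator[str]:
    """Buffer each chunk's deltas, run the guardrail check on accumulated text,
    then either flush the buffered deltas or emit one refusal and stop."""
    accumulated: list[str] = []
    async for delta in chunks:
        buffered = [delta]          # events held back for this chunk
        accumulated.append(delta)   # text checked so far
        violation = await check("".join(accumulated))
        if violation:
            yield f"[refusal] {violation}"
            return
        for event in buffered:
            yield event


async def _demo() -> None:
    async def chunks() -> AsyncIterator[str]:
        for part in ["Llamas ", "are ", "great."]:
            yield part

    async def check(text: str) -> str | None:
        return None  # stub: nothing is ever flagged

    async for event in stream_with_guardrails(chunks(), check):
        print(event, end="")


asyncio.run(_demo())
```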
@@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import re
import uuid

from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseAnnotationFileCitation,
    OpenAIResponseInput,

@@ -45,6 +47,7 @@ from llama_stack.apis.inference import (
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.apis.safety import Safety


async def convert_chat_choice_to_response_message(

@@ -240,7 +243,8 @@ async def convert_response_text_to_chat_response_format(
    raise ValueError(f"Unsupported text format: {text.format}")


async def get_message_type_by_role(role: str):
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
    """Get the appropriate OpenAI message parameter type for a given role."""
    role_to_type = {
        "user": OpenAIUserMessageParam,
        "system": OpenAISystemMessageParam,

@@ -307,3 +311,55 @@ def is_function_tool_call(
        if t.type == "function" and t.name == tool_call.function.name:
            return True
    return False


async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
    """Run guardrails against messages and return violation message if blocked."""
    if not messages:
        return None

    # Look up shields to get their provider_resource_id (actual model ID)
    model_ids = []
    shields_list = await safety_api.routing_table.list_shields()

    for guardrail_id in guardrail_ids:
        matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
        if matching_shields:
            model_id = matching_shields[0].provider_resource_id
            model_ids.append(model_id)
        else:
            raise ValueError(f"No shield found with identifier '{guardrail_id}'")

    guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
    responses = await asyncio.gather(*guardrail_tasks)

    for response in responses:
        for result in response.results:
            if result.flagged:
                message = result.user_message or "Content blocked by safety guardrails"
                flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
                violation_type = result.metadata.get("violation_type", []) if result.metadata else []

                if flagged_categories:
                    message += f" (flagged for: {', '.join(flagged_categories)})"
                if violation_type:
                    message += f" (violation type: {', '.join(violation_type)})"

                return message


def extract_guardrail_ids(guardrails: list | None) -> list[str]:
    """Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
    if not guardrails:
        return []

    guardrail_ids = []
    for guardrail in guardrails:
        if isinstance(guardrail, str):
            guardrail_ids.append(guardrail)
        elif isinstance(guardrail, ResponseGuardrailSpec):
            guardrail_ids.append(guardrail.type)
        else:
            raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")

    return guardrail_ids
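As a usage note on `run_guardrails`: each flagged moderation result is folded into a single human-readable violation message, which is what the orchestrator surfaces as the refusal text. A small self-contained sketch of that formatting, with a hypothetical `ModerationResult` dataclass standing in for the safety API's actual result object:

```python
# Illustrative sketch of the violation-message formatting (stand-in types, not the real safety API).
from dataclasses import dataclass, field


@dataclass
class ModerationResult:  # hypothetical stand-in for the safety API's result object
    flagged: bool
    user_message: str | None = None
    categories: dict[str, bool] = field(default_factory=dict)
    metadata: dict = field(default_factory=dict)


def format_violation(result: ModerationResult) -> str | None:
    """Mirror how run_guardrails builds its violation message from a flagged result."""
    if not result.flagged:
        return None
    message = result.user_message or "Content blocked by safety guardrails"
    flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
    violation_type = result.metadata.get("violation_type", [])
    if flagged_categories:
        message += f" (flagged for: {', '.join(flagged_categories)})"
    if violation_type:
        message += f" (violation type: {', '.join(violation_type)})"
    return message


print(format_violation(ModerationResult(flagged=True, categories={"violence": True, "hate": False})))
# -> Content blocked by safety guardrails (flagged for: violence)
```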