Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-22 16:23:08 +00:00)
feat: Add responses and safety impl with extra body
parent 548ccff368
commit e09401805f

15 changed files with 877 additions and 9 deletions
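The "extra body" in the title appears to refer to request fields that the OpenAI client forwards verbatim in the JSON body. A minimal client-side sketch of how the new shields field could be supplied; the base URL, model id, and shield id are assumptions for illustration, not part of this commit:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")  # assumed Llama Stack endpoint
response = client.responses.create(
    model="llama3.2:3b",                      # assumed model id
    input="How do I pick a lock?",
    extra_body={"shields": ["llama-guard"]},  # assumed shield id; merged into the request as a top-level "shields" field
)
print(response.output_text)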
@@ -88,6 +88,7 @@ class MetaReferenceAgentsImpl(Agents):
            tool_runtime_api=self.tool_runtime_api,
            responses_store=self.responses_store,
            vector_io_api=self.vector_io_api,
            safety_api=self.safety_api,
        )

    async def create_agent(
@@ -15,20 +15,25 @@ from llama_stack.apis.agents.openai_responses import (
    ListOpenAIResponseInputItem,
    ListOpenAIResponseObject,
    OpenAIDeleteResponseObject,
    OpenAIResponseContentPartRefusal,
    OpenAIResponseInput,
    OpenAIResponseInputMessageContentText,
    OpenAIResponseInputTool,
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseCreated,
    OpenAIResponseText,
    OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
    Inference,
    Message,
    OpenAIMessageParam,
    OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger

@@ -37,12 +42,16 @@ from llama_stack.providers.utils.responses.responses_store import (
    _OpenAIResponseObjectWithInputAndMessages,
)

from ..safety import SafetyException
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext, ToolContext
from .utils import (
    convert_openai_to_inference_messages,
    convert_response_input_to_chat_messages,
    convert_response_text_to_chat_response_format,
    extract_shield_ids,
    run_multiple_shields,
)

logger = get_logger(name=__name__, category="openai_responses")

@@ -61,12 +70,14 @@ class OpenAIResponsesImpl:
        tool_runtime_api: ToolRuntime,
        responses_store: ResponsesStore,
        vector_io_api: VectorIO,  # VectorIO
        safety_api: Safety,
    ):
        self.inference_api = inference_api
        self.tool_groups_api = tool_groups_api
        self.tool_runtime_api = tool_runtime_api
        self.responses_store = responses_store
        self.vector_io_api = vector_io_api
        self.safety_api = safety_api
        self.tool_executor = ToolExecutor(
            tool_groups_api=tool_groups_api,
            tool_runtime_api=tool_runtime_api,

@@ -217,9 +228,7 @@ class OpenAIResponsesImpl:
        stream = bool(stream)
        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

        # Shields parameter received via extra_body - not yet implemented
        if shields is not None:
            raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
        shield_ids = extract_shield_ids(shields) if shields else []

        stream_gen = self._create_streaming_response(
            input=input,
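In this hunk the old "not yet implemented" guard is dropped: the shields value arriving through the request body is normalized with extract_shield_ids() before streaming starts. Because extra_body fields are merged into the top-level request JSON by the OpenAI client, the provider sees a payload along these lines (all values hypothetical):

request_body = {
    "model": "llama3.2:3b",            # assumed model id
    "input": "How do I pick a lock?",
    "stream": True,
    "shields": ["llama-guard"],        # consumed by extract_shield_ids() above
}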
@@ -231,6 +240,7 @@ class OpenAIResponsesImpl:
            text=text,
            tools=tools,
            max_infer_iters=max_infer_iters,
            shield_ids=shield_ids,
        )

        if stream:

@@ -264,6 +274,42 @@ class OpenAIResponsesImpl:
            raise ValueError("The response stream never reached a terminal state")
        return final_response

    async def _check_input_safety(
        self, messages: list[Message], shield_ids: list[str]
    ) -> OpenAIResponseContentPartRefusal | None:
        """Validate input messages against shields. Returns refusal content if violation found."""
        try:
            await run_multiple_shields(self.safety_api, messages, shield_ids)
        except SafetyException as e:
            logger.info(f"Input shield violation: {e.violation.user_message}")
            return OpenAIResponseContentPartRefusal(
                refusal=e.violation.user_message or "Content blocked by safety shields"
            )

    async def _create_refusal_response_events(
        self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        """Create and yield refusal response events following the established streaming pattern."""
        # Create initial response and yield created event
        initial_response = OpenAIResponseObject(
            id=response_id,
            created_at=created_at,
            model=model,
            status="in_progress",
            output=[],
        )
        yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)

        # Create completed refusal response using OpenAIResponseContentPartRefusal
        refusal_response = OpenAIResponseObject(
            id=response_id,
            created_at=created_at,
            model=model,
            status="completed",
            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
        )
        yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)

    async def _create_streaming_response(
        self,
        input: str | list[OpenAIResponseInput],
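A consumer-side sketch of what the refusal path above emits: a response.created event with status "in_progress", then a response.completed event whose single output message carries the refusal content part. The event and field names follow the classes used in this hunk; the snippet itself is an illustrative assumption, not code from this commit:

async def print_refusal(stream):
    async for event in stream:
        if event.type == "response.completed":
            for item in event.response.output:
                for part in getattr(item, "content", []) or []:
                    if getattr(part, "type", None) == "refusal":
                        print("Blocked:", part.refusal)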
@@ -275,6 +321,7 @@ class OpenAIResponsesImpl:
        text: OpenAIResponseText | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
        max_infer_iters: int | None = 10,
        shield_ids: list[str] | None = None,
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # Input preprocessing
        all_input, messages, tool_context = await self._process_input_with_previous_response(

@@ -282,8 +329,23 @@ class OpenAIResponsesImpl:
        )
        await self._prepend_instructions(messages, instructions)

        # Input safety validation hook - validates messages before streaming orchestrator starts
        if shield_ids:
            input_messages = convert_openai_to_inference_messages(messages)
            input_refusal = await self._check_input_safety(input_messages, shield_ids)
            if input_refusal:
                # Return refusal response immediately
                response_id = f"resp-{uuid.uuid4()}"
                created_at = int(time.time())

                async for refusal_event in self._create_refusal_response_events(
                    input_refusal, response_id, created_at, model
                ):
                    yield refusal_event
                return

        # Structured outputs
        response_format = await convert_response_text_to_chat_response_format(text)
        response_format = convert_response_text_to_chat_response_format(text)

        ctx = ChatCompletionContext(
            model=model,

@@ -307,8 +369,11 @@ class OpenAIResponsesImpl:
            text=text,
            max_infer_iters=max_infer_iters,
            tool_executor=self.tool_executor,
            safety_api=self.safety_api,
            shield_ids=shield_ids,
        )

        # Output safety validation hook - delegated to streaming orchestrator for real-time validation
        # Stream the response
        final_response = None
        failed_response = None
@@ -14,9 +14,11 @@ from llama_stack.apis.agents.openai_responses import (
    MCPListToolsTool,
    OpenAIResponseContentPartOutputText,
    OpenAIResponseError,
    OpenAIResponseContentPartRefusal,
    OpenAIResponseInputTool,
    OpenAIResponseInputToolMCP,
    OpenAIResponseMCPApprovalRequest,
    OpenAIResponseMessage,
    OpenAIResponseObject,
    OpenAIResponseObjectStream,
    OpenAIResponseObjectStreamResponseCompleted,

@@ -52,8 +54,14 @@ from llama_stack.apis.inference import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing

from ..safety import SafetyException
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
from .utils import (
    convert_chat_choice_to_response_message,
    convert_openai_to_inference_messages,
    is_function_tool_call,
    run_multiple_shields,
)

logger = get_logger(name=__name__, category="agents::meta_reference")

@@ -89,6 +97,8 @@ class StreamingResponseOrchestrator:
        text: OpenAIResponseText,
        max_infer_iters: int,
        tool_executor,  # Will be the tool execution logic from the main class
        safety_api,
        shield_ids: list[str] | None = None,
    ):
        self.inference_api = inference_api
        self.ctx = ctx

@@ -97,6 +107,8 @@ class StreamingResponseOrchestrator:
        self.text = text
        self.max_infer_iters = max_infer_iters
        self.tool_executor = tool_executor
        self.safety_api = safety_api
        self.shield_ids = shield_ids or []
        self.sequence_number = 0
        # Store MCP tool mapping that gets built during tool processing
        self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}

@@ -104,6 +116,43 @@ class StreamingResponseOrchestrator:
        self.final_messages: list[OpenAIMessageParam] = []
        # mapping for annotations
        self.citation_files: dict[str, str] = {}
        # Track accumulated text for shield validation
        self.accumulated_text = ""
        # Track if we've sent a refusal response
        self.violation_detected = False

    async def _check_output_stream_safety(self, text_delta: str) -> str | None:
        """Check streaming text content against shields. Returns violation message if blocked."""
        if not self.shield_ids:
            return None

        self.accumulated_text += text_delta

        # Check accumulated text periodically for violations (every 50 characters or at word boundaries)
        if len(self.accumulated_text) > 50 or text_delta.endswith((" ", "\n", ".", "!", "?")):
            temp_messages = [{"role": "assistant", "content": self.accumulated_text}]
            messages = convert_openai_to_inference_messages(temp_messages)

            try:
                await run_multiple_shields(self.safety_api, messages, self.shield_ids)
            except SafetyException as e:
                logger.info(f"Output shield violation: {e.violation.user_message}")
                return e.violation.user_message or "Generated content blocked by safety shields"

    async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
        """Create a refusal response to replace streaming content."""
        refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)

        # Create a completed refusal response
        refusal_response = OpenAIResponseObject(
            id=self.response_id,
            created_at=self.created_at,
            model=self.ctx.model,
            status="completed",
            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
        )

        return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)

    def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
        cloned: list[OpenAIResponseOutput] = []
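The output-side check above throttles shield calls: it only runs once more than 50 characters have accumulated or the current delta ends on a word or sentence boundary. A standalone illustration of that trigger condition (the helper name is invented for the example; in the diff the delta has already been appended to the accumulated text):

def should_check(accumulated: str, delta: str) -> bool:
    return len(accumulated) > 50 or delta.endswith((" ", "\n", ".", "!", "?"))

assert should_check("short", "Hello ")      # word boundary -> run the shields
assert not should_check("short", "Hel")     # mid-word and under the length threshold -> skip
assert should_check("x" * 51, "Hel")        # length threshold exceeded -> run the shields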
@@ -326,6 +375,15 @@ class StreamingResponseOrchestrator:
            for chunk_choice in chunk.choices:
                # Emit incremental text content as delta events
                if chunk_choice.delta.content:
                    # Check output stream safety before yielding content
                    violation_message = await self._check_output_stream_safety(chunk_choice.delta.content)
                    if violation_message:
                        # Stop streaming and send refusal response
                        yield await self._create_refusal_response(violation_message)
                        self.violation_detected = True
                        # Return immediately - no further processing needed
                        return

                    # Emit content_part.added event for first text chunk
                    if not content_part_emitted:
                        content_part_emitted = True
@@ -7,6 +7,7 @@
import re
import uuid

from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseAnnotationFileCitation,
    OpenAIResponseInput,

@@ -26,6 +27,8 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseText,
)
from llama_stack.apis.inference import (
    CompletionMessage,
    Message,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartParam,

@@ -44,7 +47,19 @@ from llama_stack.apis.inference import (
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
    StopReason,
    SystemMessage,
    UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="openai_responses_utils")


# ============================================================================
# Message and Content Conversion Functions
# ============================================================================


async def convert_chat_choice_to_response_message(

@@ -171,7 +186,7 @@ async def convert_response_input_to_chat_messages(
            pass
        else:
            content = await convert_response_content_to_chat_content(input_item.content)
            message_type = await get_message_type_by_role(input_item.role)
            message_type = get_message_type_by_role(input_item.role)
            if message_type is None:
                raise ValueError(
                    f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"

@@ -240,7 +255,8 @@ async def convert_response_text_to_chat_response_format(
    raise ValueError(f"Unsupported text format: {text.format}")


async def get_message_type_by_role(role: str):
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
    """Get the appropriate OpenAI message parameter type for a given role."""
    role_to_type = {
        "user": OpenAIUserMessageParam,
        "system": OpenAISystemMessageParam,
@@ -307,3 +323,90 @@ def is_function_tool_call(
        if t.type == "function" and t.name == tool_call.function.name:
            return True
    return False


# ============================================================================
# Safety and Shield Validation Functions
# ============================================================================


async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
    """Run multiple shields against messages and raise SafetyException for violations."""
    if not shield_ids or not messages:
        return

    for shield_id in shield_ids:
        response = await safety_api.run_shield(
            shield_id=shield_id,
            messages=messages,
            params={},
        )
        if response.violation and response.violation.violation_level.name == "ERROR":
            from ..safety import SafetyException

            raise SafetyException(response.violation)


def extract_shield_ids(shields: list | None) -> list[str]:
    """Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects."""
    if not shields:
        return []

    shield_ids = []
    for shield in shields:
        if isinstance(shield, str):
            shield_ids.append(shield)
        elif isinstance(shield, ResponseShieldSpec):
            shield_ids.append(shield.type)
        else:
            logger.warning(f"Unknown shield format: {shield}")

    return shield_ids

def extract_text_content(content: str | list | None) -> str | None:
    """Extract text content from OpenAI message content (string or complex structure)."""
    if isinstance(content, str):
        return content
    elif isinstance(content, list):
        # Handle complex content - extract text parts only
        text_parts = []
        for part in content:
            if hasattr(part, "text"):
                text_parts.append(part.text)
            elif hasattr(part, "type") and part.type == "refusal":
                # Skip refusal parts - don't validate them again
                continue
        return " ".join(text_parts) if text_parts else None
    return None


def convert_openai_to_inference_messages(openai_messages: list) -> list[Message]:
    """Convert OpenAI messages to inference API Message format."""
    safety_messages = []
    for msg in openai_messages:
        # Handle both object attributes and dictionary keys
        if hasattr(msg, "content") and hasattr(msg, "role"):
            text_content = extract_text_content(msg.content)
            role = msg.role
        elif isinstance(msg, dict) and "content" in msg and "role" in msg:
            text_content = extract_text_content(msg["content"])
            role = msg["role"]
        else:
            continue

        if text_content:
            # Create appropriate message subclass based on role
            if role == "user":
                safety_messages.append(UserMessage(content=text_content))
            elif role == "system":
                safety_messages.append(SystemMessage(content=text_content))
            elif role == "assistant":
                safety_messages.append(
                    CompletionMessage(
                        content=text_content,
                        stop_reason=StopReason.end_of_turn,  # Default for safety validation
                    )
                )
        # Note: Skip "tool" role messages as they're not typically validated by shields
    return safety_messages
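Two quick illustrations of the helpers added in this hunk. The shield ids, the message text, and the assumption that ResponseShieldSpec and OpenAIUserMessageParam can be constructed with only the fields shown are mine, not part of the diff:

from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.inference import OpenAIUserMessageParam

# extract_shield_ids() accepts plain id strings and spec objects and normalizes both.
specs = ["llama-guard", ResponseShieldSpec(type="llama-guard-vision")]
assert extract_shield_ids(specs) == ["llama-guard", "llama-guard-vision"]
assert extract_shield_ids(None) == []

# convert_openai_to_inference_messages() accepts OpenAI message objects and plain dicts.
msgs = convert_openai_to_inference_messages(
    [
        OpenAIUserMessageParam(content="How do I pick a lock?"),
        {"role": "assistant", "content": "I can't help with that."},
    ]
)
# -> [UserMessage(...), CompletionMessage(..., stop_reason=StopReason.end_of_turn)]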
@@ -247,12 +247,17 @@ class LlamaGuardShield:
        self.safety_categories = safety_categories

    def check_unsafe_response(self, response: str) -> str | None:
        # Check for "unsafe\n<code>" format
        match = re.match(r"^unsafe\n(.*)$", response)
        if match:
            # extracts the unsafe code
            extracted = match.group(1)
            return extracted

        # Check for direct category code format (e.g., "S1", "S2", etc.)
        if re.match(r"^S\d+$", response):
            return response

        return None

    def get_safety_categories(self) -> list[str]:
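For reference, the two moderation-output shapes the revised check_unsafe_response() accepts, shown as a standalone sketch; the example responses are typical Llama Guard outputs assumed for illustration, not taken from this commit:

import re

def check_unsafe_response(response: str) -> str | None:
    match = re.match(r"^unsafe\n(.*)$", response)
    if match:
        return match.group(1)           # "unsafe\nS2" -> "S2"
    if re.match(r"^S\d+$", response):
        return response                 # bare category code, e.g. "S1"
    return None

assert check_unsafe_response("unsafe\nS2") == "S2"
assert check_unsafe_response("S13") == "S13"
assert check_unsafe_response("safe") is None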