address comments

Swapna Lekkala 2025-10-10 12:56:32 -07:00
parent c66757ea4d
commit 0efdc46d89
25 changed files with 1251 additions and 77 deletions

View file

@@ -34,7 +34,6 @@ from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
Message,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
@@ -47,7 +46,6 @@ from llama_stack.providers.utils.responses.responses_store import (
_OpenAIResponseObjectWithInputAndMessages,
)
from ..safety import SafetyException
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext, ToolContext
@@ -55,7 +53,6 @@ from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_shield_ids,
run_multiple_shields,
)
logger = get_logger(name=__name__, category="openai_responses")
@@ -297,18 +294,6 @@ class OpenAIResponsesImpl:
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _check_input_safety(
self, messages: list[Message], shield_ids: list[str]
) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
try:
await run_multiple_shields(self.safety_api, messages, shield_ids)
except SafetyException as e:
logger.info(f"Input shield violation: {e.violation.user_message}")
return OpenAIResponseContentPartRefusal(
refusal=e.violation.user_message or "Content blocked by safety shields"
)
async def _create_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
) -> AsyncIterator[OpenAIResponseObjectStream]:
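
The input-safety helper deleted here is not dropped; it reappears on StreamingResponseOrchestrator in the next file. A minimal sketch of the guard pattern it supports is shown below; the names respond, emit_refusal_events, and stream_model_events are illustrative stand-ins, not the repository's API, and only the shape (check input shields first, then refuse or stream) is taken from this diff.

# Hedged sketch of the relocated input-safety guard; all names are stand-ins.
from collections.abc import AsyncIterator


async def respond(orchestrator, messages) -> AsyncIterator[object]:
    # Run the input shields before any model call.
    refusal = await orchestrator._check_input_safety(messages)
    if refusal is not None:
        # Short-circuit: emit a completed response carrying the refusal.
        async for event in orchestrator.emit_refusal_events(refusal):
            yield event
        return
    # No violation: stream the normal model response.
    async for event in orchestrator.stream_model_events(messages):
        yield event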

View file

@@ -49,6 +49,7 @@ from llama_stack.apis.agents.openai_responses import (
from llama_stack.apis.inference import (
CompletionMessage,
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
@@ -126,7 +127,7 @@ class StreamingResponseOrchestrator:
# Track if we've sent a refusal response
self.violation_detected = False
async def _check_input_safety(self, messages: list[OpenAIMessageParam]) -> OpenAIResponseContentPartRefusal | None:
async def _check_input_safety(self, messages: list[Message]) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
@@ -141,13 +142,12 @@
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create refusal response events for input safety violations."""
# Create the refusal content part explicitly with the correct structure
refusal_part = OpenAIResponseContentPartRefusal(refusal=refusal_content.refusal, type="refusal")
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_part], type="message")],
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
@@ -557,7 +557,7 @@ class StreamingResponseOrchestrator:
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Safety check after processing all chunks
# Safety check after processing all choices in this chunk
if chat_response_content:
accumulated_text = "".join(chat_response_content)
violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
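
The reworded comment describes a per-chunk pattern: gather the text deltas from every choice in one streamed chunk, then run a single safety check on the accumulated text. A small self-contained sketch of that pattern follows; the chunk shape and the check_output_safety callable are assumptions for illustration, not the repository's types.

# Hedged sketch: one safety check per streamed chunk, covering all choices.
from dataclasses import dataclass, field


@dataclass
class Delta:
    content: str | None = None


@dataclass
class Choice:
    delta: Delta = field(default_factory=Delta)


@dataclass
class Chunk:
    choices: list[Choice] = field(default_factory=list)


async def check_chunk(chunk: Chunk, check_output_safety) -> str | None:
    # Collect the text deltas from every choice in this chunk.
    chat_response_content = [c.delta.content for c in chunk.choices if c.delta.content]
    if not chat_response_content:
        return None
    accumulated_text = "".join(chat_response_content)
    # Returns a violation message if a shield flags the text, otherwise None.
    return await check_output_safety(accumulated_text)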

View file

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import re
import uuid
@@ -317,12 +318,13 @@ async def run_multiple_shields(safety_api: Safety, messages: list[Message], shie
"""Run multiple shields against messages and raise SafetyException for violations."""
if not shield_ids or not messages:
return
for shield_id in shield_ids:
response = await safety_api.run_shield(
shield_id=shield_id,
messages=messages,
params={},
)
shield_tasks = [
safety_api.run_shield(shield_id=shield_id, messages=messages, params={}) for shield_id in shield_ids
]
responses = await asyncio.gather(*shield_tasks)
for response in responses:
if response.violation and response.violation.violation_level.name == "ERROR":
from ..safety import SafetyException
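
For reference, a self-contained sketch of the concurrent pattern this hunk introduces: all shields run in parallel with asyncio.gather instead of one await per loop iteration. The SafetyException class and the raise at the end are assumptions, since the hunk is cut off right after the import and the real exception signature may differ.

# Hedged, self-contained sketch of running shields concurrently.
import asyncio


class SafetyException(Exception):
    """Stand-in for the project's SafetyException; carries the violation."""

    def __init__(self, violation):
        super().__init__(getattr(violation, "user_message", ""))
        self.violation = violation


async def run_multiple_shields(safety_api, messages, shield_ids) -> None:
    if not shield_ids or not messages:
        return

    # Fan out one run_shield call per shield and await them together.
    shield_tasks = [
        safety_api.run_shield(shield_id=shield_id, messages=messages, params={})
        for shield_id in shield_ids
    ]
    responses = await asyncio.gather(*shield_tasks)

    # Raise on the first ERROR-level violation, mirroring the original check.
    for response in responses:
        if response.violation and response.violation.violation_level.name == "ERROR":
            raise SafetyException(response.violation)

Compared with awaiting the shields one at a time, gather issues every call up front, so total latency is roughly that of the slowest shield rather than the sum of all of them.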