Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-19 19:59:39 +00:00)
Commit: 0efdc46d89 ("address comments")
Parent: c66757ea4d
25 changed files with 1251 additions and 77 deletions
@ -34,7 +34,6 @@ from llama_stack.apis.conversations import Conversations
|
|||
from llama_stack.apis.conversations.conversations import ConversationItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
Message,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
|
|
@@ -47,7 +46,6 @@ from llama_stack.providers.utils.responses.responses_store import (
     _OpenAIResponseObjectWithInputAndMessages,
 )
 
 from ..safety import SafetyException
 from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
 from .types import ChatCompletionContext, ToolContext
@@ -55,7 +53,6 @@ from .utils import (
     convert_response_input_to_chat_messages,
     convert_response_text_to_chat_response_format,
     extract_shield_ids,
     run_multiple_shields,
 )
 
 logger = get_logger(name=__name__, category="openai_responses")
@@ -297,18 +294,6 @@ class OpenAIResponsesImpl:
             raise ValueError("The response stream never reached a terminal state")
         return final_response
 
-    async def _check_input_safety(
-        self, messages: list[Message], shield_ids: list[str]
-    ) -> OpenAIResponseContentPartRefusal | None:
-        """Validate input messages against shields. Returns refusal content if violation found."""
-        try:
-            await run_multiple_shields(self.safety_api, messages, shield_ids)
-        except SafetyException as e:
-            logger.info(f"Input shield violation: {e.violation.user_message}")
-            return OpenAIResponseContentPartRefusal(
-                refusal=e.violation.user_message or "Content blocked by safety shields"
-            )
-
     async def _create_refusal_response_events(
         self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
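For context, the helper removed above implements a pattern that survives in the streaming orchestrator (see the `_check_input_safety` hunk further down): run the configured shields over the input and, if one raises a safety violation, convert it into a refusal content part instead of propagating the exception. Below is a minimal, self-contained sketch of that pattern; `Violation`, `ShieldViolationError`, `RefusalPart`, and `check_with_shields` are simplified stand-ins, not the real llama-stack types.

```python
# Sketch only: the classes below are simplified stand-ins for llama-stack's
# Violation / SafetyException / refusal content-part types.
import asyncio
from dataclasses import dataclass


@dataclass
class Violation:
    user_message: str | None


class ShieldViolationError(Exception):
    def __init__(self, violation: Violation) -> None:
        super().__init__(violation.user_message)
        self.violation = violation


@dataclass
class RefusalPart:
    refusal: str
    type: str = "refusal"


async def check_with_shields(messages: list[str], shield_ids: list[str]) -> None:
    # Stand-in for run_multiple_shields: raise on the first blocking violation.
    for shield_id in shield_ids:
        if any("attack" in message.lower() for message in messages):
            raise ShieldViolationError(Violation(user_message=f"Blocked by shield {shield_id}"))


async def check_input_safety(messages: list[str], shield_ids: list[str]) -> RefusalPart | None:
    """Return a refusal part if any shield flags the input, otherwise None."""
    try:
        await check_with_shields(messages, shield_ids)
    except ShieldViolationError as e:
        return RefusalPart(refusal=e.violation.user_message or "Content blocked by safety shields")
    return None


if __name__ == "__main__":
    print(asyncio.run(check_input_safety(["plan an attack"], ["example-shield"])))
```

Returning a refusal object rather than re-raising keeps the pipeline on its normal path, so the caller can still emit a well-formed (refused) response.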
@@ -49,6 +49,7 @@ from llama_stack.apis.agents.openai_responses import (
 from llama_stack.apis.inference import (
     CompletionMessage,
     Inference,
     Message,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
@@ -126,7 +127,7 @@ class StreamingResponseOrchestrator:
         # Track if we've sent a refusal response
         self.violation_detected = False
 
-    async def _check_input_safety(self, messages: list[OpenAIMessageParam]) -> OpenAIResponseContentPartRefusal | None:
+    async def _check_input_safety(self, messages: list[Message]) -> OpenAIResponseContentPartRefusal | None:
         """Validate input messages against shields. Returns refusal content if violation found."""
         try:
             await run_multiple_shields(self.safety_api, messages, self.shield_ids)
@@ -141,13 +142,12 @@
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         """Create refusal response events for input safety violations."""
         # Create the refusal content part explicitly with the correct structure
-        refusal_part = OpenAIResponseContentPartRefusal(refusal=refusal_content.refusal, type="refusal")
         refusal_response = OpenAIResponseObject(
             id=self.response_id,
             created_at=self.created_at,
             model=self.ctx.model,
             status="completed",
-            output=[OpenAIResponseMessage(role="assistant", content=[refusal_part], type="message")],
+            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
         )
         yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
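The change above drops the intermediate `refusal_part` and places the refusal content the orchestrator already holds straight into the completed response's single output message. A rough sketch of that shape, using plain dataclasses as placeholders for the real `OpenAIResponse*` models:

```python
# Hedged sketch: a refusal surfaced as the only output item of a completed
# response. The dataclasses are illustrative placeholders, not the real models.
from dataclasses import dataclass, field


@dataclass
class RefusalPart:
    refusal: str
    type: str = "refusal"


@dataclass
class OutputMessage:
    role: str
    content: list[RefusalPart]
    type: str = "message"


@dataclass
class ResponseObject:
    id: str
    created_at: int
    model: str
    status: str
    output: list[OutputMessage] = field(default_factory=list)


def build_refusal_response(refusal: RefusalPart, response_id: str, created_at: int, model: str) -> ResponseObject:
    # Reuse the refusal part as-is instead of constructing a second, identical copy.
    return ResponseObject(
        id=response_id,
        created_at=created_at,
        model=model,
        status="completed",
        output=[OutputMessage(role="assistant", content=[refusal], type="message")],
    )
```

Passing the existing part through avoids building a second object that could drift out of sync with what the input check produced.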
@@ -557,7 +557,7 @@ class StreamingResponseOrchestrator:
                                 response_tool_call.function.arguments or ""
                             ) + tool_call.function.arguments
 
-            # Safety check after processing all chunks
+            # Safety check after processing all choices in this chunk
             if chat_response_content:
                 accumulated_text = "".join(chat_response_content)
                 violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
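The reworded comment pins down when the output check runs: once per chunk, after every choice's delta has been folded into the accumulated text, not once at the end of the stream. A small sketch of that loop; `check_text` is a hypothetical stand-in for `_check_output_stream_chunk_safety`:

```python
# Hedged sketch of per-chunk output moderation during streaming.
# check_text() stands in for the orchestrator's _check_output_stream_chunk_safety.
import asyncio
from collections.abc import AsyncIterator


async def check_text(text: str) -> str | None:
    # Hypothetical guard: return a violation message, or None if the text is fine.
    return "blocked" if "forbidden" in text else None


async def stream_with_moderation(chunks: AsyncIterator[list[str]]) -> AsyncIterator[str]:
    accumulated: list[str] = []
    async for choice_deltas in chunks:
        # Fold every choice's delta from this chunk into the running text.
        accumulated.extend(choice_deltas)
        # Safety check after processing all choices in this chunk.
        violation = await check_text("".join(accumulated))
        if violation is not None:
            yield f"[refusal: {violation}]"
            return
        for delta in choice_deltas:
            yield delta


async def _demo() -> None:
    async def chunks() -> AsyncIterator[list[str]]:
        for chunk in (["Hello, "], ["wor", "ld"]):
            yield chunk

    async for piece in stream_with_moderation(chunks()):
        print(piece, end="")
    print()


if __name__ == "__main__":
    asyncio.run(_demo())
```

Checking the running concatenation each chunk catches content that only becomes objectionable once deltas are joined, at the cost of re-scanning earlier text on every chunk.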
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import re
 import uuid
 
@@ -317,12 +318,13 @@ async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids
     """Run multiple shields against messages and raise SafetyException for violations."""
     if not shield_ids or not messages:
         return
-    for shield_id in shield_ids:
-        response = await safety_api.run_shield(
-            shield_id=shield_id,
-            messages=messages,
-            params={},
-        )
+    shield_tasks = [
+        safety_api.run_shield(shield_id=shield_id, messages=messages, params={}) for shield_id in shield_ids
+    ]
+
+    responses = await asyncio.gather(*shield_tasks)
+
+    for response in responses:
         if response.violation and response.violation.violation_level.name == "ERROR":
             from ..safety import SafetyException
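The rewrite above fires all `run_shield` calls at once with `asyncio.gather` instead of awaiting them one by one, so total latency is roughly the slowest shield rather than the sum of all of them. A self-contained sketch of the difference; `run_shield` here is a mock coroutine, not the Safety API:

```python
# Hedged sketch: sequential vs. concurrent shield execution.
# run_shield() is a mock standing in for safety_api.run_shield().
import asyncio
import time


async def run_shield(shield_id: str) -> str:
    await asyncio.sleep(0.2)  # pretend each shield takes ~200 ms
    return f"{shield_id}: ok"


async def sequential(shield_ids: list[str]) -> list[str]:
    results = []
    for shield_id in shield_ids:  # total time ~= 0.2 s * len(shield_ids)
        results.append(await run_shield(shield_id))
    return results


async def concurrent(shield_ids: list[str]) -> list[str]:
    tasks = [run_shield(shield_id) for shield_id in shield_ids]
    return list(await asyncio.gather(*tasks))  # total time ~= 0.2 s


async def _demo() -> None:
    shields = ["example-shield-1", "example-shield-2", "example-shield-3"]
    for fn in (sequential, concurrent):
        start = time.perf_counter()
        await fn(shields)
        print(f"{fn.__name__}: {time.perf_counter() - start:.2f}s")


if __name__ == "__main__":
    asyncio.run(_demo())
```

One behavioral difference worth noting: the old loop could raise on the first offending shield without calling the rest, while `gather` waits for every shield to finish before the violations are inspected.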