mirror of https://github.com/meta-llama/llama-stack.git

commit b5c08c72a7 (parent da07772480): clean

4 changed files with 33 additions and 141 deletions
@@ -64,7 +64,6 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
     OpenAIMessageParam,
-    OpenAIUserMessageParam,
 )
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -136,33 +135,16 @@ class StreamingResponseOrchestrator:
         # Track if we've sent a refusal response
         self.violation_detected = False
 
-    async def _check_input_safety(
-        self, messages: list[OpenAIUserMessageParam]
-    ) -> OpenAIResponseContentPartRefusal | None:
-        """Validate input messages against guardrails. Returns refusal content if violation found."""
-        combined_text = interleaved_content_as_str([msg.content for msg in messages])
-
-        if not combined_text:
+    async def _apply_guardrails(self, text: str, context: str = "content") -> str | None:
+        """Apply guardrails to text content. Returns violation message if blocked."""
+        if not self.guardrail_ids or not text:
             return None
 
         try:
-            await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
+            await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
         except SafetyException as e:
-            logger.info(f"Input guardrail violation: {e.violation.user_message}")
-            return OpenAIResponseContentPartRefusal(
-                refusal=e.violation.user_message or "Content blocked by safety guardrails"
-            )
-
-    async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
-        """Check accumulated streaming text content against guardrails. Returns violation message if blocked."""
-        if not self.guardrail_ids or not accumulated_text:
-            return None
-
-        try:
-            await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
-        except SafetyException as e:
-            logger.info(f"Output guardrail violation: {e.violation.user_message}")
-            return e.violation.user_message or "Generated content blocked by safety guardrails"
+            logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
+            return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
 
     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
@@ -224,10 +206,11 @@ class StreamingResponseOrchestrator:
 
         # Input safety validation - check messages before processing
         if self.guardrail_ids:
-            input_refusal = await self._check_input_safety(self.ctx.messages)
-            if input_refusal:
+            combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
+            input_violation_message = await self._apply_guardrails(combined_text, "input")
+            if input_violation_message:
                 # Return refusal response immediately
-                yield await self._create_refusal_response(input_refusal.refusal)
+                yield await self._create_refusal_response(input_violation_message)
                 return
 
         async for stream_event in self._process_tools(output_messages):
@@ -733,10 +716,10 @@ class StreamingResponseOrchestrator:
                             response_tool_call.function.arguments or ""
                         ) + tool_call.function.arguments
 
-            # Safety check after processing all choices in this chunk
+            # Output Safety Validation for a chunk
             if chat_response_content:
                 accumulated_text = "".join(chat_response_content)
-                violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
+                violation_message = await self._apply_guardrails(accumulated_text, "output")
                 if violation_message:
                     yield await self._create_refusal_response(violation_message)
                     self.violation_detected = True
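The two hunks above route both the input check and the per-chunk output check through the single _apply_guardrails helper. Below is a minimal, self-contained sketch of that flow; SafetyException and run_multiple_guardrails are simplified stand-ins for the llama-stack utilities referenced in the diff (the real ones attach a violation object with a user_message rather than a plain exception string).

# Illustrative, self-contained sketch of the single-helper guardrail flow above.
# SafetyException and run_multiple_guardrails are simplified stand-ins, not the
# real llama-stack implementations.
import asyncio


class SafetyException(Exception):
    """Raised when a guardrail flags the supplied text."""


async def run_multiple_guardrails(safety_api, text: str, guardrail_ids: list[str]) -> None:
    # Stub guardrail: flag anything containing the word "harmful".
    if "harmful" in text:
        raise SafetyException("Content blocked by safety guardrails")


async def apply_guardrails(safety_api, guardrail_ids: list[str], text: str, context: str = "content") -> str | None:
    """Return a violation message when any guardrail flags the text, else None."""
    if not guardrail_ids or not text:
        return None
    try:
        await run_multiple_guardrails(safety_api, text, guardrail_ids)
    except SafetyException as e:
        return str(e) or f"{context.capitalize()} blocked by safety guardrails"
    return None


async def main() -> None:
    # The same helper serves both the input path ("input") and the streamed
    # output path ("output"), as in the two hunks above.
    print(await apply_guardrails(None, ["llama-guard"], "hello world", "input"))      # None
    print(await apply_guardrails(None, ["llama-guard"], "harmful request", "input"))  # violation message


asyncio.run(main())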
@@ -365,20 +365,3 @@ def extract_guardrail_ids(guardrails: list | None) -> list[str]:
             raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
 
     return guardrail_ids
-
-
-def extract_text_content(content: str | list | None) -> str | None:
-    """Extract text content from OpenAI message content (string or complex structure)."""
-    if isinstance(content, str):
-        return content
-    elif isinstance(content, list):
-        # Handle complex content - extract text parts only
-        text_parts = []
-        for part in content:
-            if hasattr(part, "text"):
-                text_parts.append(part.text)
-            elif hasattr(part, "type") and part.type == "refusal":
-                # Skip refusal parts - don't validate them again
-                continue
-        return " ".join(text_parts) if text_parts else None
-    return None
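For reference, the streaming hunk above flattens message content with interleaved_content_as_str instead of a bespoke extractor. A minimal sketch of that call, assuming plain-string message content (the variable names here are illustrative, not part of the diff):

# Sketch of the flattening call used in the streaming hunk above; plain strings
# stand in for real OpenAI message content values.
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str

message_contents = ["Hello ", "world"]  # illustrative content values
combined_text = interleaved_content_as_str(message_contents)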
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock
 
 import pytest
 
@@ -14,7 +14,6 @@ from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
     extract_guardrail_ids,
-    extract_text_content,
 )
 
 
@@ -85,76 +84,3 @@ def test_extract_guardrail_ids_unknown_format(responses_impl):
     guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
     with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
         extract_guardrail_ids(guardrails)
-
-
-def test_extract_text_content_string(responses_impl):
-    """Test extraction from simple string content."""
-    content = "Hello world"
-    result = extract_text_content(content)
-    assert result == "Hello world"
-
-
-def test_extract_text_content_list_with_text(responses_impl):
-    """Test extraction from list content with text parts."""
-    content = [
-        MagicMock(text="Hello "),
-        MagicMock(text="world"),
-    ]
-    result = extract_text_content(content)
-    assert result == "Hello world"
-
-
-def test_extract_text_content_list_with_refusal(responses_impl):
-    """Test extraction skips refusal parts."""
-    # Create text parts
-    text_part1 = MagicMock()
-    text_part1.text = "Hello"
-
-    text_part2 = MagicMock()
-    text_part2.text = "world"
-
-    # Create refusal part (no text attribute)
-    refusal_part = MagicMock()
-    refusal_part.type = "refusal"
-    refusal_part.refusal = "Blocked"
-    del refusal_part.text  # Remove text attribute
-
-    content = [text_part1, refusal_part, text_part2]
-    result = extract_text_content(content)
-    assert result == "Hello world"
-
-
-def test_extract_text_content_empty_list(responses_impl):
-    """Test extraction from empty list returns None."""
-    content = []
-    result = extract_text_content(content)
-    assert result is None
-
-
-def test_extract_text_content_no_text_parts(responses_impl):
-    """Test extraction with no text parts returns None."""
-    # Create image part (no text attribute)
-    image_part = MagicMock()
-    image_part.type = "image"
-    image_part.image_url = "http://example.com"
-
-    # Create refusal part (no text attribute)
-    refusal_part = MagicMock()
-    refusal_part.type = "refusal"
-    refusal_part.refusal = "Blocked"
-
-    # Explicitly remove text attributes to simulate non-text parts
-    if hasattr(image_part, "text"):
-        delattr(image_part, "text")
-    if hasattr(refusal_part, "text"):
-        delattr(refusal_part, "text")
-
-    content = [image_part, refusal_part]
-    result = extract_text_content(content)
-    assert result is None
-
-
-def test_extract_text_content_none_input(responses_impl):
-    """Test extraction with None input returns None."""
-    result = extract_text_content(None)
-    assert result is None
@@ -9,10 +9,9 @@ from unittest.mock import AsyncMock
 import pytest
 
 from llama_stack.apis.agents.openai_responses import (
-    OpenAIResponseContentPartRefusal,
     OpenAIResponseText,
 )
-from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
     StreamingResponseOrchestrator,
@@ -79,12 +78,12 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
     assert tags_param["items"] == {"type": "string"}
 
 
-async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with no violations."""
-    messages = [UserMessage(content="Hello world")]
+async def test_apply_guardrails_no_violation(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with no violations."""
+    text = "Hello world"
     guardrail_ids = ["llama-guard"]
 
-    # Mock successful shield validation (no violation)
+    # Mock successful guardrails validation (no violation)
     mock_response = AsyncMock()
     mock_response.violation = None
     mock_safety_api.run_shield.return_value = mock_response
@@ -102,7 +101,7 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
         guardrail_ids=guardrail_ids,
     )
 
-    result = await orchestrator._check_input_safety(messages)
+    result = await orchestrator._apply_guardrails(text)
 
     assert result is None
     # Verify run_moderation was called with the correct model
@@ -112,13 +111,15 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
     assert call_args[1]["model"] == "llama-guard-model"  # The provider_resource_id from our mock
 
 
-async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with safety violation."""
-    messages = [UserMessage(content="Harmful content")]
+async def test_apply_guardrails_with_violation(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with safety violation."""
+    text = "Harmful content"
     guardrail_ids = ["llama-guard"]
 
     # Mock moderation to return flagged content
-    mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})
+    flagged_result = ModerationObjectResults(flagged=True, categories={"violence": True})
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
 
     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
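For quick reference, a standalone sketch of the typed moderation fixture built in this hunk; the field values are the ones the test uses and are illustrative only.

# Sketch of the typed moderation fixture from the hunk above; values mirror the
# test and are illustrative.
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults

flagged_result = ModerationObjectResults(flagged=True, categories={"violence": True})
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])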
@@ -133,14 +134,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
         guardrail_ids=guardrail_ids,
     )
 
-    result = await orchestrator._check_input_safety(messages)
+    result = await orchestrator._apply_guardrails(text)
 
-    assert isinstance(result, OpenAIResponseContentPartRefusal)
-    assert result.refusal == "Content flagged by moderation"
+    assert result == "Content flagged by moderation"
 
 
-async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with empty inputs."""
+async def test_apply_guardrails_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with empty inputs."""
     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
         inference_api=mock_inference_api,
@@ -154,11 +154,11 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
         guardrail_ids=[],
     )
 
-    # Test empty shield_ids
-    result = await orchestrator._check_input_safety([UserMessage(content="test")])
+    # Test empty guardrail_ids
+    result = await orchestrator._apply_guardrails("test")
     assert result is None
 
-    # Test empty messages
+    # Test empty text
    orchestrator.guardrail_ids = ["llama-guard"]
-    result = await orchestrator._check_input_safety([])
+    result = await orchestrator._apply_guardrails("")
     assert result is None