Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-14 15:22:31 +00:00)
use guardrails and run_moderation api
This commit is contained in: parent 171fb7101d · commit c10db23d7a
16 changed files with 184 additions and 195 deletions
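The change in a nutshell: the Responses tests now pass a guardrails list through the OpenAI client's extra_body where they previously passed shields. A minimal sketch of that call pattern, assuming an OpenAI-compatible client pointed at a running llama-stack server; the base_url, model ID, and guardrail ID below are placeholders, not values taken from this commit:

from openai import OpenAI

# Placeholder endpoint and credentials for a local llama-stack deployment (assumption).
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model ID
    input="How can I hurt someone?",
    stream=False,
    # Guardrails are not part of the OpenAI schema, so they ride along in extra_body.
    extra_body={"guardrails": ["llama-guard"]},
)

# Per the updated tests, a guardrail that blocks the input surfaces a refusal item.
message = response.output[0]
print(message.content[0].type)  # expected: "refusal" when the input is blocked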
@@ -300,13 +300,13 @@ def test_function_call_output_response_with_none_arguments(openai_client, client


 @pytest.mark.parametrize("stream", [True, False])
-def test_input_safety_shields(compat_client, text_model_id, stream):
-    """Test that input safety shields block inappropriate content."""
+def test_input_safety_guardrails(compat_client, text_model_id, stream):
+    """Test that input safety guardrails block inappropriate content."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")

-    # Test with shield that should block inappropriate content
-    # Note: This test assumes shields are configured in the test environment
+    # Test with guardrail that should block inappropriate content
+    # Note: This test assumes guardrails are configured in the test environment
     inappropriate_message = "How can I hurt someone?"

     response = compat_client.responses.create(
@@ -318,7 +318,7 @@ def test_input_safety_shields(compat_client, text_model_id, stream):
             }
         ],
         stream=stream,
-        extra_body={"shields": ["llama-guard"]},  # Shields passed via extra_body
+        extra_body={"guardrails": ["llama-guard"]},  # Guardrails passed via extra_body
     )

     if stream:
@@ -334,11 +334,11 @@ def test_input_safety_shields(compat_client, text_model_id, stream):
     else:
         response_to_check = response

-    # Check if the response contains a refusal (indicating shield blocked content)
+    # Check if the response contains a refusal (indicating guardrail blocked content)
     assert response_to_check.output[0].type == "message"
     message = response_to_check.output[0]

-    # Shield should ALWAYS block inappropriate input - expect refusal response
+    # Guardrail should ALWAYS block inappropriate input - expect refusal response
     assert len(message.content) > 0, "Content should not be empty"
     content_item = message.content[0]
     assert content_item.type == "refusal", f"Expected refusal content, but got: {content_item.type}"
@@ -346,16 +346,16 @@ def test_input_safety_shields(compat_client, text_model_id, stream):


 @pytest.mark.parametrize("stream", [True, False])
-def test_output_safety_shields(compat_client, text_model_id, stream):
-    """Test that output safety shields can block generated content."""
+def test_output_safety_guardrails(compat_client, text_model_id, stream):
+    """Test that output safety guardrails can block generated content."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")

-    # Use an ambiguous prompt that could potentially generate content requiring shield evaluation
+    # Use an ambiguous prompt that could potentially generate content requiring guardrail evaluation
     # This prompt: "Describe violent content."
     # Could result in:
-    # 1. Educational violent content → shield allows output → normal response
-    # 2. Violent techniques content → shield blocks output → refusal response
+    # 1. Educational violent content → guardrail allows output → normal response
+    # 2. Violent techniques content → guardrail blocks output → refusal response
     response = compat_client.responses.create(
         model=text_model_id,
         input=[
@@ -365,7 +365,7 @@ def test_output_safety_shields(compat_client, text_model_id, stream):
             }
         ],
         stream=stream,
-        extra_body={"shields": ["llama-guard"]},  # Output shield validation
+        extra_body={"guardrails": ["llama-guard"]},  # Output guardrail validation
     )

     if stream:
@@ -388,8 +388,8 @@ def test_output_safety_shields(compat_client, text_model_id, stream):
         assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"


-def test_shields_with_tools(compat_client, text_model_id):
-    """Test that shields work correctly when tools are present."""
+def test_guardrails_with_tools(compat_client, text_model_id):
+    """Test that guardrails work correctly when tools are present."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")

@@ -414,11 +414,11 @@ def test_shields_with_tools(compat_client, text_model_id):
                 },
             }
         ],
-        extra_body={"shields": ["llama-guard"]},
+        extra_body={"guardrails": ["llama-guard"]},
        stream=False,
     )

-    # Verify response completes successfully with tools and shields
+    # Verify response completes successfully with tools and guardrails
     assert response.id is not None
     assert len(response.output) > 0

@ -1,34 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test for extra_body parameter support with shields example.
|
||||
|
||||
This test demonstrates that parameters marked with ExtraBodyField annotation
|
||||
can be passed via extra_body in the client SDK and are received by the
|
||||
server-side implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import APIStatusError
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Shields are not yet implemented inside responses")
|
||||
def test_shields_via_extra_body(compat_client, text_model_id):
|
||||
"""Test that shields parameter is received by the server and raises NotImplementedError."""
|
||||
|
||||
# Test with shields as list of strings (shield IDs)
|
||||
with pytest.raises((APIStatusError, NotImplementedError)) as exc_info:
|
||||
compat_client.responses.create(
|
||||
model=text_model_id,
|
||||
input="What is the capital of France?",
|
||||
stream=False,
|
||||
extra_body={"shields": ["test-shield-1", "test-shield-2"]},
|
||||
)
|
||||
|
||||
# Verify the error message indicates shields are not implemented
|
||||
error_message = str(exc_info.value)
|
||||
assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower()
|
||||
|
|
@@ -8,12 +8,12 @@ from unittest.mock import AsyncMock, MagicMock

 import pytest

-from llama_stack.apis.agents.agents import ResponseShieldSpec
+from llama_stack.apis.agents.agents import ResponseGuardrailSpec
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
     OpenAIResponsesImpl,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
-    extract_shield_ids,
+    extract_guardrail_ids,
     extract_text_content,
 )

@@ -38,53 +38,53 @@ def responses_impl(mock_apis):
     return OpenAIResponsesImpl(**mock_apis)


-def test_extract_shield_ids_from_strings(responses_impl):
-    """Test extraction from simple string shield IDs."""
-    shields = ["llama-guard", "content-filter", "nsfw-detector"]
-    result = extract_shield_ids(shields)
+def test_extract_guardrail_ids_from_strings(responses_impl):
+    """Test extraction from simple string guardrail IDs."""
+    guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
+    result = extract_guardrail_ids(guardrails)
     assert result == ["llama-guard", "content-filter", "nsfw-detector"]


-def test_extract_shield_ids_from_objects(responses_impl):
-    """Test extraction from ResponseShieldSpec objects."""
-    shields = [
-        ResponseShieldSpec(type="llama-guard"),
-        ResponseShieldSpec(type="content-filter"),
+def test_extract_guardrail_ids_from_objects(responses_impl):
+    """Test extraction from ResponseGuardrailSpec objects."""
+    guardrails = [
+        ResponseGuardrailSpec(type="llama-guard"),
+        ResponseGuardrailSpec(type="content-filter"),
     ]
-    result = extract_shield_ids(shields)
+    result = extract_guardrail_ids(guardrails)
     assert result == ["llama-guard", "content-filter"]


-def test_extract_shield_ids_mixed_formats(responses_impl):
+def test_extract_guardrail_ids_mixed_formats(responses_impl):
     """Test extraction from mixed string and object formats."""
-    shields = [
+    guardrails = [
         "llama-guard",
-        ResponseShieldSpec(type="content-filter"),
+        ResponseGuardrailSpec(type="content-filter"),
         "nsfw-detector",
     ]
-    result = extract_shield_ids(shields)
+    result = extract_guardrail_ids(guardrails)
     assert result == ["llama-guard", "content-filter", "nsfw-detector"]


-def test_extract_shield_ids_none_input(responses_impl):
+def test_extract_guardrail_ids_none_input(responses_impl):
     """Test extraction with None input."""
-    result = extract_shield_ids(None)
+    result = extract_guardrail_ids(None)
     assert result == []


-def test_extract_shield_ids_empty_list(responses_impl):
+def test_extract_guardrail_ids_empty_list(responses_impl):
     """Test extraction with empty list."""
-    result = extract_shield_ids([])
+    result = extract_guardrail_ids([])
     assert result == []


-def test_extract_shield_ids_unknown_format(responses_impl):
-    """Test extraction with unknown shield format raises ValueError."""
-    # Create an object that's neither string nor ResponseShieldSpec
-    unknown_object = {"invalid": "format"}  # Plain dict, not ResponseShieldSpec
-    shields = ["valid-shield", unknown_object, "another-shield"]
-    with pytest.raises(ValueError, match="Unknown shield format.*expected str or ResponseShieldSpec"):
-        extract_shield_ids(shields)
+def test_extract_guardrail_ids_unknown_format(responses_impl):
+    """Test extraction with unknown guardrail format raises ValueError."""
+    # Create an object that's neither string nor ResponseGuardrailSpec
+    unknown_object = {"invalid": "format"}  # Plain dict, not ResponseGuardrailSpec
+    guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
+    with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
+        extract_guardrail_ids(guardrails)


 def test_extract_text_content_string(responses_impl):
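Taken together, the renamed unit tests pin down the contract of extract_guardrail_ids: plain strings pass through as guardrail IDs, ResponseGuardrailSpec objects contribute their type field, None or an empty list yields [], and any other value raises ValueError. A minimal sketch consistent with those assertions; the actual helper in llama_stack.providers.inline.agents.meta_reference.responses.utils may differ in detail, and the assumption that the guardrail identifier lives on the spec's type attribute comes only from the constructors used in the tests:

from llama_stack.apis.agents.agents import ResponseGuardrailSpec


def extract_guardrail_ids(guardrails) -> list[str]:
    """Normalize the guardrails argument into a flat list of guardrail IDs.

    Sketch based on the unit tests above: strings pass through unchanged,
    ResponseGuardrailSpec objects contribute their `type`, and None/[] yield [].
    """
    if not guardrails:
        return []

    guardrail_ids: list[str] = []
    for guardrail in guardrails:
        if isinstance(guardrail, str):
            guardrail_ids.append(guardrail)
        elif isinstance(guardrail, ResponseGuardrailSpec):
            guardrail_ids.append(guardrail.type)
        else:
            # Matches the error the tests expect for unsupported entries.
            raise ValueError(
                f"Unknown guardrail format: {guardrail!r}, expected str or ResponseGuardrailSpec"
            )
    return guardrail_ids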