mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 14:12:49 +00:00
use guardrails and run_moderation api
This commit is contained in:
parent
171fb7101d
commit
c10db23d7a
16 changed files with 184 additions and 195 deletions
|
|
@ -8,12 +8,12 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseShieldSpec
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
extract_shield_ids,
|
||||
extract_guardrail_ids,
|
||||
extract_text_content,
|
||||
)
|
||||
|
||||
|
|
@ -38,53 +38,53 @@ def responses_impl(mock_apis):
|
|||
return OpenAIResponsesImpl(**mock_apis)
|
||||
|
||||
|
||||
def test_extract_shield_ids_from_strings(responses_impl):
|
||||
"""Test extraction from simple string shield IDs."""
|
||||
shields = ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
result = extract_shield_ids(shields)
|
||||
def test_extract_guardrail_ids_from_strings(responses_impl):
|
||||
"""Test extraction from simple string guardrail IDs."""
|
||||
guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_shield_ids_from_objects(responses_impl):
|
||||
"""Test extraction from ResponseShieldSpec objects."""
|
||||
shields = [
|
||||
ResponseShieldSpec(type="llama-guard"),
|
||||
ResponseShieldSpec(type="content-filter"),
|
||||
def test_extract_guardrail_ids_from_objects(responses_impl):
|
||||
"""Test extraction from ResponseGuardrailSpec objects."""
|
||||
guardrails = [
|
||||
ResponseGuardrailSpec(type="llama-guard"),
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
]
|
||||
result = extract_shield_ids(shields)
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter"]
|
||||
|
||||
|
||||
def test_extract_shield_ids_mixed_formats(responses_impl):
|
||||
def test_extract_guardrail_ids_mixed_formats(responses_impl):
|
||||
"""Test extraction from mixed string and object formats."""
|
||||
shields = [
|
||||
guardrails = [
|
||||
"llama-guard",
|
||||
ResponseShieldSpec(type="content-filter"),
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
"nsfw-detector",
|
||||
]
|
||||
result = extract_shield_ids(shields)
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_shield_ids_none_input(responses_impl):
|
||||
def test_extract_guardrail_ids_none_input(responses_impl):
|
||||
"""Test extraction with None input."""
|
||||
result = extract_shield_ids(None)
|
||||
result = extract_guardrail_ids(None)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_shield_ids_empty_list(responses_impl):
|
||||
def test_extract_guardrail_ids_empty_list(responses_impl):
|
||||
"""Test extraction with empty list."""
|
||||
result = extract_shield_ids([])
|
||||
result = extract_guardrail_ids([])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_shield_ids_unknown_format(responses_impl):
|
||||
"""Test extraction with unknown shield format raises ValueError."""
|
||||
# Create an object that's neither string nor ResponseShieldSpec
|
||||
unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec
|
||||
shields = ["valid-shield", unknown_object, "another-shield"]
|
||||
with pytest.raises(ValueError, match="Unknown shield format.*expected str or ResponseShieldSpec"):
|
||||
extract_shield_ids(shields)
|
||||
def test_extract_guardrail_ids_unknown_format(responses_impl):
|
||||
"""Test extraction with unknown guardrail format raises ValueError."""
|
||||
# Create an object that's neither string nor ResponseGuardrailSpec
|
||||
unknown_object = {"invalid": "format"} # Plain dict, not ResponseGuardrailSpec
|
||||
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
|
||||
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
|
||||
extract_guardrail_ids(guardrails)
|
||||
|
||||
|
||||
def test_extract_text_content_string(responses_impl):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue