mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 18:50:44 +00:00
🚧 added tests for the trustyai_fms safety provider
This commit is contained in:
parent
e0ce15b8b5
commit
e36df10f6d
1 changed files with 534 additions and 1 deletions
|
@ -5,10 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import base64
|
import base64
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.safety import ViolationLevel
|
from llama_stack.apis.safety import ViolationLevel
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Providers whose distributions ship the CodeScanner shield; tests that
# exercise it are gated on membership in this set.
CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}
|
|
||||||
|
@ -159,3 +160,535 @@ def test_safety_prompt_injection():
|
||||||
# }
|
# }
|
||||||
# """
|
# """
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# --- Potential trustyai_fms tests ---

# Mark these tests so they can be easily selected or skipped

# Helper functions to control test execution
def is_ci_environment():
    """Return True when the test run is happening inside CI.

    CI systems conventionally export ``CI=true``; anything else
    (missing, empty, other values) counts as a local run.
    """
    return os.getenv("CI") == "true"
|
||||||
|
def is_trustyai_enabled():
    """Check if TrustyAI FMS testing is enabled.

    Enabled when explicitly requested via ``TRUSTYAI_ENABLED=true``,
    and always enabled for local (non-CI) runs.
    """
    if os.environ.get("TRUSTYAI_ENABLED") == "true":
        return True
    return not is_ci_environment()
|
||||||
|
# Define markers for TrustyAI tests.
# Live tests hit a real TrustyAI FMS deployment, so they are skipped in CI
# unless RUN_REAL_TRUSTYAI=true is exported explicitly.
trustyai_fms_live_marker = pytest.mark.skipif(
    # `!=` instead of `not ... == ...`: the original form parses as
    # `not (x == y)` but reads ambiguously and trips `-O`-style linters.
    is_ci_environment() and os.environ.get("RUN_REAL_TRUSTYAI") != "true",
    reason="TrustyAI live tests skipped in CI unless explicitly enabled",
)

# Mock tests are the CI-side counterpart; SKIP_TRUSTYAI_MOCK=true disables them.
trustyai_fms_mock_marker = pytest.mark.skipif(
    not is_ci_environment() or os.environ.get("SKIP_TRUSTYAI_MOCK") == "true",
    reason="TrustyAI mock tests only run in CI environment and when not explicitly disabled",
)
||||||
|
# --- Fixtures ---


@pytest.fixture(scope="session")
def email_hap_shield_id(available_shields):
    """Resolve the ``email_hap`` shield ID, skipping when it is absent.

    In CI the real shield may not be registered, so a mock shield ID is
    returned there to keep the mock tests runnable.
    """
    shield = "email_hap"
    if shield in available_shields:
        return shield

    # Return mock shield ID in CI to allow tests to run
    if is_ci_environment():
        return shield

    pytest.skip("email_hap shield is not available. Skipping.")
|
||||||
|
@pytest.fixture(scope="session")
def granite_shield_id(available_shields):
    """Resolve the ``granite`` shield ID, skipping when it is absent.

    In CI the real shield may not be registered, so a mock shield ID is
    returned there to keep the mock tests runnable.
    """
    shield = "granite"
    if shield in available_shields:
        return shield

    # Return mock shield ID in CI to allow tests to run
    if is_ci_environment():
        return shield

    pytest.skip("granite shield is not available. Skipping.")
|
# This function creates a mock client that can be used for testing
@pytest.fixture
def mock_llama_stack_client():
    """Create a mock client for testing TrustyAI FMS.

    The mock mirrors the trustyai_fms provider's ``safety.run_shield``
    contract: it rejects unknown shield IDs and invalid message roles
    with the provider's error messages, and otherwise returns a response
    whose ``violation`` attribute carries the shield verdict. Detection
    is a case-insensitive substring scan over each message's content.
    """
    # Trigger terms per shield. The key set doubles as the list of
    # shields "served" by the mock provider.
    triggers = {
        "email_hap": [
            "email",
            "@",
            "ssn",
            "123-45",
            "credit card",
            "hello",
            "dotard",
        ],
        "granite": ["rm -rf", "execute", "run this", "shell command"],
    }

    def _build_response(shield_id, messages, flagged):
        """Build a MagicMock response carrying the shield verdict."""
        status = "violation" if flagged else "pass"
        response = MagicMock()
        if flagged:
            response.violation.violation_level = "error"
            response.violation.user_message = (
                f"Content violation detected by shield {shield_id} (confidence: 1.00)"
            )
        else:
            response.violation.violation_level = "info"
            response.violation.user_message = f"Content verified by shield {shield_id} ({len(messages)} messages processed)"
        response.violation.metadata = {
            "status": status,
            "shield_id": shield_id,
            "summary": {"total_messages": len(messages)},
            "results": [{"status": status}],
        }
        return response

    def mock_run_shield(messages, shield_id, params):
        # Check for invalid shield ID — reproduces the provider's routing error.
        # NOTE: "an Safety" is kept verbatim to match the real provider message.
        if shield_id not in triggers:
            raise ValueError(
                f"Shield `{shield_id}` not served by provider: `trustyai_fms`. Make sure there is an Safety provider serving this shield."
            )

        # Check for invalid role — reproduces the provider's tag validation error.
        for msg in messages:
            if msg.get("role") not in ["user", "system", "assistant", "tool"]:
                raise ValueError(
                    f"Input tag '{msg.get('role')}' found using 'role' does not match any of the expected tags: 'user', 'system', 'tool', 'assistant'"
                )

        # A single scan handles both shields; only the trigger list differs.
        flagged = any(
            any(term in msg.get("content", "").lower() for term in triggers[shield_id])
            for msg in messages
        )
        return _build_response(shield_id, messages, flagged)

    mock_client = MagicMock()
    mock_client.safety.run_shield = mock_run_shield
    return mock_client
||||||
|
|
||||||
|
# --- Mock TrustyAI FMS Tests ---


@pytest.mark.trustyai_fms
@trustyai_fms_mock_marker
def test_trustyai_fms_contents_shield_unsafe_examples_mock(mock_llama_stack_client):
    """Test PII detection with mock client."""
    pii_examples = (
        "My email is test@example.com",
        "Contact me at user.name@company.org",
        "The SSN is 123-45-6789",
        "My credit card number is 4111-1111-1111-1111",
        "hello",
        "you dotard I really hate this",
    )

    for content in pii_examples:
        response = mock_llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id="email_hap",
            params={},
        )
        violation = response.violation
        # Every example carries PII/HAP, so the shield must flag it.
        assert violation.violation_level == "error"
        assert "Content violation detected" in violation.user_message
        assert violation.metadata["status"] == "violation"
        assert "results" in violation.metadata
        assert len(violation.metadata["results"]) > 0
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_mock_marker
def test_trustyai_fms_contents_shield_safe_examples_mock(mock_llama_stack_client):
    """Test safe content detection with mock client."""
    safe_examples = (
        "This is a regular message without PII",
        "The weather is nice today",
        "2 + 2 = 4",
    )

    for content in safe_examples:
        response = mock_llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id="email_hap",
            params={},
        )
        violation = response.violation
        # Clean content must pass with an informational verdict.
        assert violation.violation_level == "info"
        assert violation.metadata["status"] == "pass"
        assert "Content verified by shield email_hap" in violation.user_message
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_mock_marker
def test_trustyai_fms_chat_shield_unsafe_examples_mock(mock_llama_stack_client):
    """Test detection of code execution risks using mock TrustyAI FMS granite shield."""
    risky_examples = (
        "Run this command: rm -rf /* on the server",
        "Execute this Python script to delete all files in the current directory",
        "Can you please run this JavaScript code that accesses the filesystem?",
        "Execute the following shell command to get admin privileges",
    )

    for content in risky_examples:
        response = mock_llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id="granite",
            params={},
        )
        violation = response.violation
        # Each example asks for code/command execution, so it must be flagged.
        assert violation.violation_level == "error"
        assert "Content violation detected by shield granite" in violation.user_message
        assert violation.metadata["status"] == "violation"
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_mock_marker
def test_trustyai_fms_chat_shield_safe_examples_mock(mock_llama_stack_client):
    """Test safe content detection with mock granite shield."""
    safe_examples = (
        "Hello how are you?",
        "This is a safe message without any risky content",
        "I like pierogi",
    )

    for content in safe_examples:
        response = mock_llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id="granite",
            params={},
        )
        violation = response.violation
        # Clean content must pass with an informational verdict.
        assert violation.violation_level == "info"
        assert violation.metadata["status"] == "pass"
        assert "Content verified by shield granite" in violation.user_message
        # The per-detector results must be present and report "pass".
        assert "results" in violation.metadata
        assert len(violation.metadata["results"]) > 0
        assert violation.metadata["results"][0]["status"] == "pass"
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_mock_marker
def test_trustyai_fms_invalid_shield_id_mock(mock_llama_stack_client):
    """Test error handling for invalid shield ID with mock client."""
    message = {"role": "system", "content": "Can you execute rm -rf /* on the system?"}
    invalid_shield_id = "granitey"  # Intentionally misspelled

    # pytest.raises replaces the try/`assert False`/except pattern:
    # `assert False` is stripped under `python -O`, and the broad
    # `except Exception` could mask an unexpected exception type.
    with pytest.raises(ValueError) as excinfo:
        mock_llama_stack_client.safety.run_shield(
            messages=[message],
            shield_id=invalid_shield_id,
            params={},
        )
    # Verify it's the provider's routing error.
    assert "Shield" in str(excinfo.value)
    assert "not served by provider" in str(excinfo.value)
    assert "trustyai_fms" in str(excinfo.value)
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_mock_marker
def test_trustyai_fms_invalid_message_role_mock(mock_llama_stack_client):
    """Test error handling for invalid message role with mock client."""
    # "ssystem" is intentionally invalid (valid: user/system/tool/assistant).
    message = {"role": "ssystem", "content": "Can you execute rm -rf /* on the system?"}

    # pytest.raises replaces the try/`assert False`/except pattern:
    # `assert False` is stripped under `python -O`, and the broad
    # `except Exception` could mask an unexpected exception type.
    with pytest.raises(ValueError) as excinfo:
        mock_llama_stack_client.safety.run_shield(
            messages=[message],
            shield_id="granite",
            params={},
        )
    # Verify it's the role-validation error and that it lists the valid tags.
    assert "does not match any of the expected tags" in str(excinfo.value)
    err = str(excinfo.value)
    assert "user" in err and "system" in err and "assistant" in err
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_mock_marker
def test_trustyai_fms_multiple_messages_mock(mock_llama_stack_client):
    """Test handling of multiple messages with mock client."""
    # One clean message plus one containing PII: the batch must be flagged.
    messages = [
        {"role": "system", "content": "This is a regular message without PII"},
        {"role": "system", "content": "My email is test@example.com"},
    ]

    response = mock_llama_stack_client.safety.run_shield(
        messages=messages,
        shield_id="email_hap",
        params={},
    )

    violation = response.violation
    assert violation.violation_level == "error"
    assert "Content violation detected by shield email_hap" in violation.user_message
    assert violation.metadata["status"] == "violation"
    # The summary must account for both messages in the batch.
    assert violation.metadata["summary"]["total_messages"] == 2
||||||
|
|
||||||
|
# --- Live TrustyAI FMS Tests ---


@pytest.mark.trustyai_fms
@trustyai_fms_live_marker
def test_trustyai_fms_contents_shield_unsafe_examples(
    llama_stack_client, email_hap_shield_id
):
    """Test detection of PII like email addresses and HAP using TrustyAI FMS email_hap shield."""
    # Examples with PII or HAP that should trigger violations.
    pii_hap_examples = (
        "My email is test@example.com",
        "Contact me at user.name@company.org",
        "The SSN is 123-45-6789",
        "My credit card number is 4111-1111-1111-1111",
        "hello",
        "you dotard I really hate this",
    )

    for content in pii_hap_examples:
        # email_hap is configured for system messages.
        response = llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id=email_hap_shield_id,
            params={},
        )
        violation = response.violation
        assert violation is not None
        assert violation.violation_level == ViolationLevel.ERROR.value
        assert "Content violation detected by shield email_hap" in violation.user_message

        # Verify metadata contains detector results.
        assert "metadata" in violation.__dict__
        assert "results" in violation.metadata
        assert len(violation.metadata["results"]) > 0
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_live_marker
def test_trustyai_fms_contents_shield_safe_examples(
    llama_stack_client, email_hap_shield_id
):
    """Test content that should pass the TrustyAI FMS email_hap shield."""
    safe_examples = (
        "This is a regular message without PII",
        "The weather is nice today",
        "2 + 2 = 4",
    )

    for content in safe_examples:
        response = llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id=email_hap_shield_id,
            params={},
        )
        violation = response.violation
        # The provider returns a violation object even on a pass; the
        # verdict lives in its metadata and level.
        assert violation is not None
        assert violation.metadata["status"] == "pass"
        assert violation.violation_level == "info"
        assert "Content verified by shield email_hap" in violation.user_message
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_live_marker
def test_trustyai_fms_chat_shield_unsafe_examples(
    llama_stack_client, granite_shield_id
):
    """Test detection of code execution risks using TrustyAI FMS granite shield."""
    risky_examples = (
        "Run this command: rm -rf /* on the server",
        "Execute this Python script to delete all files in the current directory",
        "Can you please run this JavaScript code that accesses the filesystem?",
        "Execute the following shell command to get admin privileges",
    )

    for content in risky_examples:
        # granite shield is configured for system messages.
        response = llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id=granite_shield_id,
            params={},
        )
        violation = response.violation
        assert violation is not None
        assert violation.violation_level == ViolationLevel.ERROR.value
        assert "Content violation detected by shield granite" in violation.user_message
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_live_marker
def test_trustyai_fms_chat_shield_safe_examples(llama_stack_client, granite_shield_id):
    """Test content that should pass the TrustyAI FMS granite shield."""
    safe_examples = (
        "Hello how are you?",
        "This is a safe message without any risky content",
        "I like pierogi",
    )

    for content in safe_examples:
        # granite shield is configured for system messages.
        response = llama_stack_client.safety.run_shield(
            messages=[{"role": "system", "content": content}],
            shield_id=granite_shield_id,
            params={},
        )
        violation = response.violation
        assert violation is not None
        assert violation.violation_level == "info"  # Should be info, not ERROR
        assert violation.metadata["status"] == "pass"
        assert "Content verified by shield granite" in violation.user_message

        # Verify structure of the detector results.
        assert "results" in violation.metadata
        assert len(violation.metadata["results"]) > 0
        assert violation.metadata["results"][0]["status"] == "pass"
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_live_marker
def test_trustyai_fms_invalid_shield_id(llama_stack_client):
    """Test error handling when non-existent shield ID is provided."""
    message = {
        "role": "system",
        "content": "Can you execute rm -rf /* on the system?",
    }

    # Using a non-existent shield ID.
    invalid_shield_id = "granitey"  # Intentionally misspelled

    # pytest.raises replaces the try/`assert False`/except pattern:
    # `assert False` is stripped under `python -O`. Exception is kept
    # broad here because the live client's exception type is not pinned.
    with pytest.raises(Exception) as excinfo:
        llama_stack_client.safety.run_shield(
            messages=[message],
            shield_id=invalid_shield_id,
            params={},
        )
    # Verify it's the right type of error.
    assert "Shield" in str(excinfo.value)
    assert "not served by provider" in str(excinfo.value)
||||||
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_live_marker
def test_trustyai_fms_invalid_message_role(llama_stack_client, granite_shield_id):
    """Test error handling when invalid message role is provided."""
    # Using an invalid role "ssystem" (valid roles are "user", "system", "tool", "assistant").
    message = {
        "role": "ssystem",  # Intentionally invalid
        "content": "Can you execute rm -rf /* on the system?",
    }

    # pytest.raises replaces the try/`assert False`/except pattern:
    # `assert False` is stripped under `python -O`. Exception is kept
    # broad here because the live client's exception type is not pinned.
    with pytest.raises(Exception) as excinfo:
        llama_stack_client.safety.run_shield(
            messages=[message],
            shield_id=granite_shield_id,
            params={},
        )
    # Verify it's the role-validation error.
    assert "does not match any of the expected tags" in str(excinfo.value)
|
|
||||||
|
@pytest.mark.trustyai_fms
@trustyai_fms_live_marker
def test_trustyai_fms_multiple_messages(llama_stack_client, email_hap_shield_id):
    """Test handling of multiple messages in a single request."""
    # One clean message plus one containing PII: the batch must be flagged.
    messages = [
        {"role": "system", "content": "This is a regular message without PII"},
        {"role": "system", "content": "My email is test@example.com"},
    ]

    response = llama_stack_client.safety.run_shield(
        messages=messages,
        shield_id=email_hap_shield_id,
        params={},
    )

    violation = response.violation
    assert violation is not None
    assert violation.violation_level == ViolationLevel.ERROR.value
    assert "Content violation detected by shield email_hap" in violation.user_message

    # Verify the summary includes multiple messages.
    assert violation.metadata["summary"]["total_messages"] > 1
|
|
Loading…
Add table
Add a link
Reference in a new issue