Mirror of https://github.com/meta-llama/llama-stack.git

Commit 628e41193d (parent e36df10f6d)
🚧 ruff check and fixes to test_safety.py

1 changed file (test_safety.py) with 23 additions and 50 deletions
@@ -6,10 +6,11 @@
 import base64
 import mimetypes
 import os
+from unittest.mock import MagicMock
 
 import pytest
 
 from llama_stack.apis.safety import ViolationLevel
-from unittest.mock import MagicMock
+
 
 CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}
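This first hunk is an import-sorting fix: unittest.mock is part of the standard library, so ruff's isort-style rule (I001) moves it up with the other stdlib imports, ahead of the third-party (pytest) and first-party (llama_stack) groups. A sketch of the grouping the rule enforces; the comments are illustrative and not part of the file:

# Standard-library imports: grouped first, alphabetized
import base64
import mimetypes
import os
from unittest.mock import MagicMock

# Third-party imports
import pytest

# First-party imports
from llama_stack.apis.safety import ViolationLevel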
@@ -258,9 +259,7 @@ def mock_llama_stack_client():
 
             if has_pii:
                 response.violation.violation_level = "error"
-                response.violation.user_message = (
-                    "Content violation detected by shield email_hap (confidence: 1.00)"
-                )
+                response.violation.user_message = "Content violation detected by shield email_hap (confidence: 1.00)"
                 response.violation.metadata = {
                     "status": "violation",
                     "shield_id": "email_hap",
@@ -269,7 +268,9 @@ def mock_llama_stack_client():
                 }
             else:
                 response.violation.violation_level = "info"
-                response.violation.user_message = f"Content verified by shield email_hap ({len(messages)} messages processed)"
+                response.violation.user_message = (
+                    f"Content verified by shield email_hap ({len(messages)} messages processed)"
+                )
                 response.violation.metadata = {
                     "status": "pass",
                     "shield_id": "email_hap",
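The user_message hunks above and below are formatting-only: wrapping a long assignment in parentheses, or joining it back onto one line, never changes the resulting string, it only keeps the physical line within the configured length. A tiny standalone illustration, not taken from the file:

# Both assignments produce the identical string; only the layout differs.
one_line = "Content violation detected by shield email_hap (confidence: 1.00)"
wrapped = (
    "Content violation detected by shield email_hap (confidence: 1.00)"
)
assert one_line == wrapped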
@@ -280,16 +281,11 @@ def mock_llama_stack_client():
         # Handle granite shield (code execution)
         elif shield_id == "granite":
             risky_terms = ["rm -rf", "execute", "run this", "shell command"]
-            has_risk = any(
-                any(term in msg.get("content", "").lower() for term in risky_terms)
-                for msg in messages
-            )
+            has_risk = any(any(term in msg.get("content", "").lower() for term in risky_terms) for msg in messages)
 
             if has_risk:
                 response.violation.violation_level = "error"
-                response.violation.user_message = (
-                    "Content violation detected by shield granite (confidence: 1.00)"
-                )
+                response.violation.user_message = "Content violation detected by shield granite (confidence: 1.00)"
                 response.violation.metadata = {
                     "status": "violation",
                     "shield_id": "granite",
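The collapsed has_risk expression in this hunk is a nested generator: the inner any() asks whether a single message contains any risky term, case-insensitively, and the outer any() asks whether any message does. A quick self-contained check of that logic; the sample messages are made up for illustration:

risky_terms = ["rm -rf", "execute", "run this", "shell command"]
messages = [
    {"role": "user", "content": "Please RUN THIS script on the server"},
    {"role": "user", "content": "What is the weather today?"},
]

# True: the first message, lowercased, contains the term "run this".
has_risk = any(any(term in msg.get("content", "").lower() for term in risky_terms) for msg in messages)
print(has_risk)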
@@ -298,7 +294,9 @@ def mock_llama_stack_client():
                 }
             else:
                 response.violation.violation_level = "info"
-                response.violation.user_message = f"Content verified by shield granite ({len(messages)} messages processed)"
+                response.violation.user_message = (
+                    f"Content verified by shield granite ({len(messages)} messages processed)"
+                )
                 response.violation.metadata = {
                     "status": "pass",
                     "shield_id": "granite",
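The hunks above all sit inside the mock_llama_stack_client fixture, which fakes shield responses so the *_mock tests can run without a live server. A minimal sketch of the idea, assuming a MagicMock-backed client and reusing the attribute names visible in the diff; this is not the repository's actual fixture:

from unittest.mock import MagicMock

client = MagicMock()
client.safety.run_shield.return_value = MagicMock(
    violation=MagicMock(
        violation_level="error",
        user_message="Content violation detected by shield email_hap (confidence: 1.00)",
        metadata={"status": "violation", "shield_id": "email_hap"},
    )
)

# The call shape mirrors the arguments seen in the tests (shield_id, messages, params).
response = client.safety.run_shield(
    shield_id="email_hap",
    messages=[{"role": "user", "content": "Contact me at jane.doe@example.com"}],
    params={},
)
assert response.violation.violation_level == "error"
assert "Content violation detected by shield email_hap" in response.violation.user_message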
@@ -383,10 +381,7 @@ def test_trustyai_fms_chat_shield_unsafe_examples_mock(mock_llama_stack_client):
             params={},
         )
         assert response.violation.violation_level == "error"
-        assert (
-            "Content violation detected by shield granite"
-            in response.violation.user_message
-        )
+        assert "Content violation detected by shield granite" in response.violation.user_message
         assert response.violation.metadata["status"] == "violation"
 
 
@@ -428,7 +423,7 @@ def test_trustyai_fms_invalid_shield_id_mock(mock_llama_stack_client):
             shield_id=invalid_shield_id,
             params={},
         )
-        assert False, "Expected an exception for invalid shield ID but none was raised"
+        raise AssertionError("Expected an exception for invalid shield ID but none was raised")
     except Exception as e:
         assert "Shield" in str(e) and "not served by provider" in str(e)
         assert "trustyai_fms" in str(e)
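This hunk, and the similar ones further down, replace assert False, msg with raise AssertionError(msg). That is the fix ruff suggests for the flake8-bugbear rule B011: assert statements are stripped when Python runs with the -O flag, so an assert False guard can silently disappear, while an explicit raise always fires. A small illustration of the difference, not code from the repository:

def guard_with_assert() -> None:
    # Compiled away entirely under `python -O`, so the failure is silent.
    assert False, "Expected an exception but none was raised"


def guard_with_raise() -> None:
    # Survives -O and always signals the failure.
    raise AssertionError("Expected an exception but none was raised")


try:
    guard_with_raise()
except AssertionError as exc:
    print(f"caught: {exc}")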
@@ -446,9 +441,7 @@ def test_trustyai_fms_invalid_message_role_mock(mock_llama_stack_client):
             shield_id="granite",
             params={},
         )
-        assert (
-            False
-        ), "Expected an exception for invalid message role but none was raised"
+        raise AssertionError("Expected an exception for invalid message role but none was raised")
     except Exception as e:
         assert "does not match any of the expected tags" in str(e)
         assert "user" in str(e) and "system" in str(e) and "assistant" in str(e)
@@ -470,10 +463,7 @@ def test_trustyai_fms_multiple_messages_mock(mock_llama_stack_client):
     )
 
     assert response.violation.violation_level == "error"
-    assert (
-        "Content violation detected by shield email_hap"
-        in response.violation.user_message
-    )
+    assert "Content violation detected by shield email_hap" in response.violation.user_message
     assert response.violation.metadata["status"] == "violation"
    assert response.violation.metadata["summary"]["total_messages"] == 2
 
@@ -483,9 +473,7 @@ def test_trustyai_fms_multiple_messages_mock(mock_llama_stack_client):
 
 @pytest.mark.trustyai_fms
 @trustyai_fms_live_marker
-def test_trustyai_fms_contents_shield_unsafe_examples(
-    llama_stack_client, email_hap_shield_id
-):
+def test_trustyai_fms_contents_shield_unsafe_examples(llama_stack_client, email_hap_shield_id):
     """Test detection of PII like email addresses and HAP using TrustyAI FMS email_hap shield."""
     # Examples with PII or HAP that should trigger violations
     pii_hap_examples = [
@@ -509,10 +497,7 @@ def test_trustyai_fms_contents_shield_unsafe_examples(
         )
         assert response.violation is not None
         assert response.violation.violation_level == ViolationLevel.ERROR.value
-        assert (
-            "Content violation detected by shield email_hap"
-            in response.violation.user_message
-        )
+        assert "Content violation detected by shield email_hap" in response.violation.user_message
 
         # Verify metadata contains detector results
         assert "metadata" in response.violation.__dict__
@@ -522,9 +507,7 @@ def test_trustyai_fms_contents_shield_unsafe_examples(
 
 @pytest.mark.trustyai_fms
 @trustyai_fms_live_marker
-def test_trustyai_fms_contents_shield_safe_examples(
-    llama_stack_client, email_hap_shield_id
-):
+def test_trustyai_fms_contents_shield_safe_examples(llama_stack_client, email_hap_shield_id):
     """Test content that should pass the TrustyAI FMS email_hap shield."""
     safe_examples = [
         "This is a regular message without PII",
@@ -550,9 +533,7 @@ def test_trustyai_fms_contents_shield_safe_examples(
 
 @pytest.mark.trustyai_fms
 @trustyai_fms_live_marker
-def test_trustyai_fms_chat_shield_unsafe_examples(
-    llama_stack_client, granite_shield_id
-):
+def test_trustyai_fms_chat_shield_unsafe_examples(llama_stack_client, granite_shield_id):
     """Test detection of code execution risks using TrustyAI FMS granite shield."""
     risky_examples = [
         "Run this command: rm -rf /* on the server",
@@ -573,10 +554,7 @@ def test_trustyai_fms_chat_shield_unsafe_examples(
         )
         assert response.violation is not None
         assert response.violation.violation_level == ViolationLevel.ERROR.value
-        assert (
-            "Content violation detected by shield granite"
-            in response.violation.user_message
-        )
+        assert "Content violation detected by shield granite" in response.violation.user_message
 
 
 @pytest.mark.trustyai_fms
@@ -630,7 +608,7 @@ def test_trustyai_fms_invalid_shield_id(llama_stack_client):
             params={},
         )
         # If we get here, the test should fail because an exception wasn't raised
-        assert False, "Expected an exception for invalid shield ID but none was raised"
+        raise AssertionError("Expected an exception for invalid shield ID but none was raised")
     except Exception as e:
         # Verify it's the right type of error
         assert "Shield" in str(e) and "not served by provider" in str(e)
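The invalid-shield-id and invalid-role tests keep their original try/except structure and only swap the unreachable-path guard. A common alternative layout in pytest is pytest.raises with a match pattern, which fails automatically when no exception is raised; shown here only as a sketch with a stand-in function, not as what this commit changes:

import re

import pytest


def call_with_invalid_shield(shield_id: str):
    # Stand-in for the real client.safety.run_shield call in these tests.
    raise ValueError(f"Shield {shield_id} not served by provider: trustyai_fms")


def test_invalid_shield_id_raises():
    with pytest.raises(Exception, match=re.escape("not served by provider")):
        call_with_invalid_shield("nonexistent-shield")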
@@ -654,9 +632,7 @@ def test_trustyai_fms_invalid_message_role(llama_stack_client, granite_shield_id
             params={},
         )
         # If we get here, the test should fail because an exception wasn't raised
-        assert (
-            False
-        ), "Expected an exception for invalid message role but none was raised"
+        raise AssertionError("Expected an exception for invalid message role but none was raised")
     except Exception as e:
         # Verify it's the right type of error
         assert "does not match any of the expected tags" in str(e)
@@ -685,10 +661,7 @@ def test_trustyai_fms_multiple_messages(llama_stack_client, email_hap_shield_id)
 
     assert response.violation is not None
     assert response.violation.violation_level == ViolationLevel.ERROR.value
-    assert (
-        "Content violation detected by shield email_hap"
-        in response.violation.user_message
-    )
+    assert "Content violation detected by shield email_hap" in response.violation.user_message
 
     # Verify the summary includes multiple messages
     assert response.violation.metadata["summary"]["total_messages"] > 1