🚧 ruff check and fixes to test_safety.py

This commit is contained in:
m-misiura 2025-03-25 18:55:24 +00:00
parent e36df10f6d
commit 628e41193d

View file

@ -6,10 +6,11 @@
import base64 import base64
import mimetypes import mimetypes
import os import os
from unittest.mock import MagicMock
import pytest import pytest
from llama_stack.apis.safety import ViolationLevel from llama_stack.apis.safety import ViolationLevel
from unittest.mock import MagicMock
CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"} CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"}
@ -258,9 +259,7 @@ def mock_llama_stack_client():
if has_pii: if has_pii:
response.violation.violation_level = "error" response.violation.violation_level = "error"
response.violation.user_message = ( response.violation.user_message = "Content violation detected by shield email_hap (confidence: 1.00)"
"Content violation detected by shield email_hap (confidence: 1.00)"
)
response.violation.metadata = { response.violation.metadata = {
"status": "violation", "status": "violation",
"shield_id": "email_hap", "shield_id": "email_hap",
@ -269,7 +268,9 @@ def mock_llama_stack_client():
} }
else: else:
response.violation.violation_level = "info" response.violation.violation_level = "info"
response.violation.user_message = f"Content verified by shield email_hap ({len(messages)} messages processed)" response.violation.user_message = (
f"Content verified by shield email_hap ({len(messages)} messages processed)"
)
response.violation.metadata = { response.violation.metadata = {
"status": "pass", "status": "pass",
"shield_id": "email_hap", "shield_id": "email_hap",
@ -280,16 +281,11 @@ def mock_llama_stack_client():
# Handle granite shield (code execution) # Handle granite shield (code execution)
elif shield_id == "granite": elif shield_id == "granite":
risky_terms = ["rm -rf", "execute", "run this", "shell command"] risky_terms = ["rm -rf", "execute", "run this", "shell command"]
has_risk = any( has_risk = any(any(term in msg.get("content", "").lower() for term in risky_terms) for msg in messages)
any(term in msg.get("content", "").lower() for term in risky_terms)
for msg in messages
)
if has_risk: if has_risk:
response.violation.violation_level = "error" response.violation.violation_level = "error"
response.violation.user_message = ( response.violation.user_message = "Content violation detected by shield granite (confidence: 1.00)"
"Content violation detected by shield granite (confidence: 1.00)"
)
response.violation.metadata = { response.violation.metadata = {
"status": "violation", "status": "violation",
"shield_id": "granite", "shield_id": "granite",
@ -298,7 +294,9 @@ def mock_llama_stack_client():
} }
else: else:
response.violation.violation_level = "info" response.violation.violation_level = "info"
response.violation.user_message = f"Content verified by shield granite ({len(messages)} messages processed)" response.violation.user_message = (
f"Content verified by shield granite ({len(messages)} messages processed)"
)
response.violation.metadata = { response.violation.metadata = {
"status": "pass", "status": "pass",
"shield_id": "granite", "shield_id": "granite",
@ -383,10 +381,7 @@ def test_trustyai_fms_chat_shield_unsafe_examples_mock(mock_llama_stack_client):
params={}, params={},
) )
assert response.violation.violation_level == "error" assert response.violation.violation_level == "error"
assert ( assert "Content violation detected by shield granite" in response.violation.user_message
"Content violation detected by shield granite"
in response.violation.user_message
)
assert response.violation.metadata["status"] == "violation" assert response.violation.metadata["status"] == "violation"
@ -428,7 +423,7 @@ def test_trustyai_fms_invalid_shield_id_mock(mock_llama_stack_client):
shield_id=invalid_shield_id, shield_id=invalid_shield_id,
params={}, params={},
) )
assert False, "Expected an exception for invalid shield ID but none was raised" raise AssertionError("Expected an exception for invalid shield ID but none was raised")
except Exception as e: except Exception as e:
assert "Shield" in str(e) and "not served by provider" in str(e) assert "Shield" in str(e) and "not served by provider" in str(e)
assert "trustyai_fms" in str(e) assert "trustyai_fms" in str(e)
@ -446,9 +441,7 @@ def test_trustyai_fms_invalid_message_role_mock(mock_llama_stack_client):
shield_id="granite", shield_id="granite",
params={}, params={},
) )
assert ( raise AssertionError("Expected an exception for invalid message role but none was raised")
False
), "Expected an exception for invalid message role but none was raised"
except Exception as e: except Exception as e:
assert "does not match any of the expected tags" in str(e) assert "does not match any of the expected tags" in str(e)
assert "user" in str(e) and "system" in str(e) and "assistant" in str(e) assert "user" in str(e) and "system" in str(e) and "assistant" in str(e)
@ -470,10 +463,7 @@ def test_trustyai_fms_multiple_messages_mock(mock_llama_stack_client):
) )
assert response.violation.violation_level == "error" assert response.violation.violation_level == "error"
assert ( assert "Content violation detected by shield email_hap" in response.violation.user_message
"Content violation detected by shield email_hap"
in response.violation.user_message
)
assert response.violation.metadata["status"] == "violation" assert response.violation.metadata["status"] == "violation"
assert response.violation.metadata["summary"]["total_messages"] == 2 assert response.violation.metadata["summary"]["total_messages"] == 2
@ -483,9 +473,7 @@ def test_trustyai_fms_multiple_messages_mock(mock_llama_stack_client):
@pytest.mark.trustyai_fms @pytest.mark.trustyai_fms
@trustyai_fms_live_marker @trustyai_fms_live_marker
def test_trustyai_fms_contents_shield_unsafe_examples( def test_trustyai_fms_contents_shield_unsafe_examples(llama_stack_client, email_hap_shield_id):
llama_stack_client, email_hap_shield_id
):
"""Test detection of PII like email addresses and HAP using TrustyAI FMS email_hap shield.""" """Test detection of PII like email addresses and HAP using TrustyAI FMS email_hap shield."""
# Examples with PII or HAP that should trigger violations # Examples with PII or HAP that should trigger violations
pii_hap_examples = [ pii_hap_examples = [
@ -509,10 +497,7 @@ def test_trustyai_fms_contents_shield_unsafe_examples(
) )
assert response.violation is not None assert response.violation is not None
assert response.violation.violation_level == ViolationLevel.ERROR.value assert response.violation.violation_level == ViolationLevel.ERROR.value
assert ( assert "Content violation detected by shield email_hap" in response.violation.user_message
"Content violation detected by shield email_hap"
in response.violation.user_message
)
# Verify metadata contains detector results # Verify metadata contains detector results
assert "metadata" in response.violation.__dict__ assert "metadata" in response.violation.__dict__
@ -522,9 +507,7 @@ def test_trustyai_fms_contents_shield_unsafe_examples(
@pytest.mark.trustyai_fms @pytest.mark.trustyai_fms
@trustyai_fms_live_marker @trustyai_fms_live_marker
def test_trustyai_fms_contents_shield_safe_examples( def test_trustyai_fms_contents_shield_safe_examples(llama_stack_client, email_hap_shield_id):
llama_stack_client, email_hap_shield_id
):
"""Test content that should pass the TrustyAI FMS email_hap shield.""" """Test content that should pass the TrustyAI FMS email_hap shield."""
safe_examples = [ safe_examples = [
"This is a regular message without PII", "This is a regular message without PII",
@ -550,9 +533,7 @@ def test_trustyai_fms_contents_shield_safe_examples(
@pytest.mark.trustyai_fms @pytest.mark.trustyai_fms
@trustyai_fms_live_marker @trustyai_fms_live_marker
def test_trustyai_fms_chat_shield_unsafe_examples( def test_trustyai_fms_chat_shield_unsafe_examples(llama_stack_client, granite_shield_id):
llama_stack_client, granite_shield_id
):
"""Test detection of code execution risks using TrustyAI FMS granite shield.""" """Test detection of code execution risks using TrustyAI FMS granite shield."""
risky_examples = [ risky_examples = [
"Run this command: rm -rf /* on the server", "Run this command: rm -rf /* on the server",
@ -573,10 +554,7 @@ def test_trustyai_fms_chat_shield_unsafe_examples(
) )
assert response.violation is not None assert response.violation is not None
assert response.violation.violation_level == ViolationLevel.ERROR.value assert response.violation.violation_level == ViolationLevel.ERROR.value
assert ( assert "Content violation detected by shield granite" in response.violation.user_message
"Content violation detected by shield granite"
in response.violation.user_message
)
@pytest.mark.trustyai_fms @pytest.mark.trustyai_fms
@ -630,7 +608,7 @@ def test_trustyai_fms_invalid_shield_id(llama_stack_client):
params={}, params={},
) )
# If we get here, the test should fail because an exception wasn't raised # If we get here, the test should fail because an exception wasn't raised
assert False, "Expected an exception for invalid shield ID but none was raised" raise AssertionError("Expected an exception for invalid shield ID but none was raised")
except Exception as e: except Exception as e:
# Verify it's the right type of error # Verify it's the right type of error
assert "Shield" in str(e) and "not served by provider" in str(e) assert "Shield" in str(e) and "not served by provider" in str(e)
@ -654,9 +632,7 @@ def test_trustyai_fms_invalid_message_role(llama_stack_client, granite_shield_id
params={}, params={},
) )
# If we get here, the test should fail because an exception wasn't raised # If we get here, the test should fail because an exception wasn't raised
assert ( raise AssertionError("Expected an exception for invalid message role but none was raised")
False
), "Expected an exception for invalid message role but none was raised"
except Exception as e: except Exception as e:
# Verify it's the right type of error # Verify it's the right type of error
assert "does not match any of the expected tags" in str(e) assert "does not match any of the expected tags" in str(e)
@ -685,10 +661,7 @@ def test_trustyai_fms_multiple_messages(llama_stack_client, email_hap_shield_id)
assert response.violation is not None assert response.violation is not None
assert response.violation.violation_level == ViolationLevel.ERROR.value assert response.violation.violation_level == ViolationLevel.ERROR.value
assert ( assert "Content violation detected by shield email_hap" in response.violation.user_message
"Content violation detected by shield email_hap"
in response.violation.user_message
)
# Verify the summary includes multiple messages # Verify the summary includes multiple messages
assert response.violation.metadata["summary"]["total_messages"] > 1 assert response.violation.metadata["summary"]["total_messages"] > 1