From 628e41193d443cf776f73b63ccb03bd643c2cda5 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Tue, 25 Mar 2025 18:55:24 +0000 Subject: [PATCH] :construction: ruff check and fixes to test_safety.py --- tests/integration/safety/test_safety.py | 73 ++++++++----------------- 1 file changed, 23 insertions(+), 50 deletions(-) diff --git a/tests/integration/safety/test_safety.py b/tests/integration/safety/test_safety.py index 11af0c71f..671af6e38 100644 --- a/tests/integration/safety/test_safety.py +++ b/tests/integration/safety/test_safety.py @@ -6,10 +6,11 @@ import base64 import mimetypes import os +from unittest.mock import MagicMock + import pytest from llama_stack.apis.safety import ViolationLevel -from unittest.mock import MagicMock CODE_SCANNER_ENABLED_PROVIDERS = {"ollama", "together", "fireworks"} @@ -258,9 +259,7 @@ def mock_llama_stack_client(): if has_pii: response.violation.violation_level = "error" - response.violation.user_message = ( - "Content violation detected by shield email_hap (confidence: 1.00)" - ) + response.violation.user_message = "Content violation detected by shield email_hap (confidence: 1.00)" response.violation.metadata = { "status": "violation", "shield_id": "email_hap", @@ -269,7 +268,9 @@ def mock_llama_stack_client(): } else: response.violation.violation_level = "info" - response.violation.user_message = f"Content verified by shield email_hap ({len(messages)} messages processed)" + response.violation.user_message = ( + f"Content verified by shield email_hap ({len(messages)} messages processed)" + ) response.violation.metadata = { "status": "pass", "shield_id": "email_hap", @@ -280,16 +281,11 @@ def mock_llama_stack_client(): # Handle granite shield (code execution) elif shield_id == "granite": risky_terms = ["rm -rf", "execute", "run this", "shell command"] - has_risk = any( - any(term in msg.get("content", "").lower() for term in risky_terms) - for msg in messages - ) + has_risk = any(any(term in msg.get("content", "").lower() for term in risky_terms) for msg in messages) if has_risk: response.violation.violation_level = "error" - response.violation.user_message = ( - "Content violation detected by shield granite (confidence: 1.00)" - ) + response.violation.user_message = "Content violation detected by shield granite (confidence: 1.00)" response.violation.metadata = { "status": "violation", "shield_id": "granite", @@ -298,7 +294,9 @@ def mock_llama_stack_client(): } else: response.violation.violation_level = "info" - response.violation.user_message = f"Content verified by shield granite ({len(messages)} messages processed)" + response.violation.user_message = ( + f"Content verified by shield granite ({len(messages)} messages processed)" + ) response.violation.metadata = { "status": "pass", "shield_id": "granite", @@ -383,10 +381,7 @@ def test_trustyai_fms_chat_shield_unsafe_examples_mock(mock_llama_stack_client): params={}, ) assert response.violation.violation_level == "error" - assert ( - "Content violation detected by shield granite" - in response.violation.user_message - ) + assert "Content violation detected by shield granite" in response.violation.user_message assert response.violation.metadata["status"] == "violation" @@ -428,7 +423,7 @@ def test_trustyai_fms_invalid_shield_id_mock(mock_llama_stack_client): shield_id=invalid_shield_id, params={}, ) - assert False, "Expected an exception for invalid shield ID but none was raised" + raise AssertionError("Expected an exception for invalid shield ID but none was raised") except Exception as e: assert "Shield" in str(e) and "not served by provider" in str(e) assert "trustyai_fms" in str(e) @@ -446,9 +441,7 @@ def test_trustyai_fms_invalid_message_role_mock(mock_llama_stack_client): shield_id="granite", params={}, ) - assert ( - False - ), "Expected an exception for invalid message role but none was raised" + raise AssertionError("Expected an exception for invalid message role but none was raised") except Exception as e: assert "does not match any of the expected tags" in str(e) assert "user" in str(e) and "system" in str(e) and "assistant" in str(e) @@ -470,10 +463,7 @@ def test_trustyai_fms_multiple_messages_mock(mock_llama_stack_client): ) assert response.violation.violation_level == "error" - assert ( - "Content violation detected by shield email_hap" - in response.violation.user_message - ) + assert "Content violation detected by shield email_hap" in response.violation.user_message assert response.violation.metadata["status"] == "violation" assert response.violation.metadata["summary"]["total_messages"] == 2 @@ -483,9 +473,7 @@ def test_trustyai_fms_multiple_messages_mock(mock_llama_stack_client): @pytest.mark.trustyai_fms @trustyai_fms_live_marker -def test_trustyai_fms_contents_shield_unsafe_examples( - llama_stack_client, email_hap_shield_id -): +def test_trustyai_fms_contents_shield_unsafe_examples(llama_stack_client, email_hap_shield_id): """Test detection of PII like email addresses and HAP using TrustyAI FMS email_hap shield.""" # Examples with PII or HAP that should trigger violations pii_hap_examples = [ @@ -509,10 +497,7 @@ def test_trustyai_fms_contents_shield_unsafe_examples( ) assert response.violation is not None assert response.violation.violation_level == ViolationLevel.ERROR.value - assert ( - "Content violation detected by shield email_hap" - in response.violation.user_message - ) + assert "Content violation detected by shield email_hap" in response.violation.user_message # Verify metadata contains detector results assert "metadata" in response.violation.__dict__ @@ -522,9 +507,7 @@ def test_trustyai_fms_contents_shield_unsafe_examples( @pytest.mark.trustyai_fms @trustyai_fms_live_marker -def test_trustyai_fms_contents_shield_safe_examples( - llama_stack_client, email_hap_shield_id -): +def test_trustyai_fms_contents_shield_safe_examples(llama_stack_client, email_hap_shield_id): """Test content that should pass the TrustyAI FMS email_hap shield.""" safe_examples = [ "This is a regular message without PII", @@ -550,9 +533,7 @@ def test_trustyai_fms_contents_shield_safe_examples( @pytest.mark.trustyai_fms @trustyai_fms_live_marker -def test_trustyai_fms_chat_shield_unsafe_examples( - llama_stack_client, granite_shield_id -): +def test_trustyai_fms_chat_shield_unsafe_examples(llama_stack_client, granite_shield_id): """Test detection of code execution risks using TrustyAI FMS granite shield.""" risky_examples = [ "Run this command: rm -rf /* on the server", @@ -573,10 +554,7 @@ def test_trustyai_fms_chat_shield_unsafe_examples( ) assert response.violation is not None assert response.violation.violation_level == ViolationLevel.ERROR.value - assert ( - "Content violation detected by shield granite" - in response.violation.user_message - ) + assert "Content violation detected by shield granite" in response.violation.user_message @pytest.mark.trustyai_fms @@ -630,7 +608,7 @@ def test_trustyai_fms_invalid_shield_id(llama_stack_client): params={}, ) # If we get here, the test should fail because an exception wasn't raised - assert False, "Expected an exception for invalid shield ID but none was raised" + raise AssertionError("Expected an exception for invalid shield ID but none was raised") except Exception as e: # Verify it's the right type of error assert "Shield" in str(e) and "not served by provider" in str(e) @@ -654,9 +632,7 @@ def test_trustyai_fms_invalid_message_role(llama_stack_client, granite_shield_id params={}, ) # If we get here, the test should fail because an exception wasn't raised - assert ( - False - ), "Expected an exception for invalid message role but none was raised" + raise AssertionError("Expected an exception for invalid message role but none was raised") except Exception as e: # Verify it's the right type of error assert "does not match any of the expected tags" in str(e) @@ -685,10 +661,7 @@ def test_trustyai_fms_multiple_messages(llama_stack_client, email_hap_shield_id) assert response.violation is not None assert response.violation.violation_level == ViolationLevel.ERROR.value - assert ( - "Content violation detected by shield email_hap" - in response.violation.user_message - ) + assert "Content violation detected by shield email_hap" in response.violation.user_message # Verify the summary includes multiple messages assert response.violation.metadata["summary"]["total_messages"] > 1