diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
index 0669b65bb..8b35bd9d3 100644
--- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
+++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
@@ -6,9 +6,10 @@
 
 import logging
 import uuid
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
-from codeshield.cs import CodeShield, CodeShieldScanResult
+if TYPE_CHECKING:
+    from codeshield.cs import CodeShieldScanResult
 
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import (
@@ -59,6 +60,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
         if not shield:
             raise ValueError(f"Shield {shield_id} not found")
 
+        from codeshield.cs import CodeShield
+
         text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
         log.info(f"Running CodeScannerShield on {text[50:]}")
         result = await CodeShield.scan_code(text)
@@ -72,7 +75,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             )
         return RunShieldResponse(violation=violation)
 
-    def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults:
+    def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
         categories = {}
         category_scores = {}
         category_applied_input_types = {}
@@ -102,6 +105,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
         inputs = input if isinstance(input, list) else [input]
         results = []
 
+        from codeshield.cs import CodeShield
+
         for text_input in inputs:
             log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
             scan_result = await CodeShield.scan_code(text_input)
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index a1133f932..5d52c5d89 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -455,7 +455,7 @@ class LlamaGuardShield:
 
     def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
         """Check if content is safe based on response and unsafe code."""
-        if response.strip() == SAFE_RESPONSE:
+        if response.strip().lower().startswith(SAFE_RESPONSE):
            return True
 
        if unsafe_code:
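
The `code_scanner.py` changes apply a common pattern for optional dependencies: keep the heavy import under `TYPE_CHECKING` (plus quoted annotations) so type checkers still see the real type, and import the runtime symbol inside the methods that use it, so merely importing the provider module no longer requires `codeshield` to be installed. A minimal standalone sketch of that pattern follows; `heavydep`, `Scanner`, and `ScanResult` are hypothetical names standing in for the optional dependency, not part of this PR:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy/pyright); never executed at
    # runtime, so importing this module works without "heavydep" installed.
    from heavydep import ScanResult


def scan(text: str) -> "ScanResult":
    # Quoted annotation stays a string at runtime, so it needs no import.
    # Deferred import: the optional dependency is only required when this
    # function is actually called, not when the module is loaded.
    from heavydep import Scanner

    return Scanner().scan(text)
```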
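
The `llama_guard.py` change relaxes the safe-verdict check from exact equality to a case-insensitive prefix match. A small sketch of the before/after behavior, assuming `SAFE_RESPONSE` is the lowercase literal `"safe"` as defined elsewhere in that file:

```python
SAFE_RESPONSE = "safe"  # assumed value of the module-level constant


def is_content_safe(response: str) -> bool:
    # Prefix match instead of equality, so verdicts like "Safe" or
    # "safe\n" (casing/whitespace variations) are still treated as safe.
    return response.strip().lower().startswith(SAFE_RESPONSE)


assert is_content_safe("safe")
assert is_content_safe("Safe\n")
assert not is_content_safe("unsafe\nS1")  # "unsafe" does not start with "safe"
```

Note the prefix match still routes `unsafe` verdicts (e.g. `"unsafe\nS1"`) to the violation path, since `"unsafe"` does not begin with `"safe"`.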