diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py
index c76673d2a..738ecded3 100644
--- a/llama_stack/core/routers/safety.py
+++ b/llama_stack/core/routers/safety.py
@@ -6,9 +6,7 @@
 
 from typing import Any
 
-from llama_stack.apis.inference import (
-    Message,
-)
+from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
@@ -68,6 +66,7 @@ class SafetyRouter(Safety):
         list_shields_response = await self.routing_table.list_shields()
 
         matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
+
         if not matches:
             raise ValueError(f"No shield associated with provider_resource id {model}")
         if len(matches) > 1:
diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml
index e6e699b62..900371d95 100644
--- a/llama_stack/distributions/ci-tests/build.yaml
+++ b/llama_stack/distributions/ci-tests/build.yaml
@@ -28,6 +28,7 @@ distribution_spec:
     - provider_type: inline::localfs
     safety:
     - provider_type: inline::llama-guard
+    - provider_type: inline::code-scanner
    agents:
     - provider_type: inline::meta-reference
     telemetry:
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index 05e1b4576..0da15857c 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -134,6 +134,8 @@ providers:
     provider_type: inline::llama-guard
     config:
       excluded_categories: []
+  - provider_id: code-scanner
+    provider_type: inline::code-scanner
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -215,6 +217,9 @@ shields:
 - shield_id: llama-guard
   provider_id: ${env.SAFETY_MODEL:+llama-guard}
   provider_shield_id: ${env.SAFETY_MODEL:=}
+- shield_id: code-scanner
+  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
+  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
 vector_dbs: []
 datasets: []
 scoring_fns: []
diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml
index 1a4f81d49..b5f2b9694 100644
--- a/llama_stack/distributions/starter/build.yaml
+++ b/llama_stack/distributions/starter/build.yaml
@@ -28,6 +28,7 @@ distribution_spec:
     - provider_type: inline::localfs
     safety:
     - provider_type: inline::llama-guard
+    - provider_type: inline::code-scanner
     agents:
     - provider_type: inline::meta-reference
     telemetry:
diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml
index 46bd12956..d1771580b 100644
--- a/llama_stack/distributions/starter/run.yaml
+++ b/llama_stack/distributions/starter/run.yaml
@@ -134,6 +134,8 @@ providers:
     provider_type: inline::llama-guard
     config:
       excluded_categories: []
+  - provider_id: code-scanner
+    provider_type: inline::code-scanner
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -215,6 +217,9 @@ shields:
 - shield_id: llama-guard
   provider_id: ${env.SAFETY_MODEL:+llama-guard}
   provider_shield_id: ${env.SAFETY_MODEL:=}
+- shield_id: code-scanner
+  provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
+  provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
 vector_dbs: []
 datasets: []
 scoring_fns: []
diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py
index 0270b68ad..90fca335a 100644
--- a/llama_stack/distributions/starter/starter.py
+++ b/llama_stack/distributions/starter/starter.py
@@ -15,19 +15,14 @@ from llama_stack.core.datatypes import (
     ToolGroupInput,
 )
 from llama_stack.core.utils.dynamic import instantiate_class_type
-from llama_stack.distributions.template import (
-    DistributionTemplate,
-    RunConfigSettings,
-)
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.datatypes import RemoteProviderSpec
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
-from llama_stack.providers.inline.vector_io.milvus.config import (
-    MilvusVectorIOConfig,
-)
+from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
     SQLiteVectorIOConfig,
 )
@@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
             BuildProvider(provider_type="remote::pgvector"),
         ],
         "files": [BuildProvider(provider_type="inline::localfs")],
-        "safety": [BuildProvider(provider_type="inline::llama-guard")],
+        "safety": [
+            BuildProvider(provider_type="inline::llama-guard"),
+            BuildProvider(provider_type="inline::code-scanner"),
+        ],
         "agents": [BuildProvider(provider_type="inline::meta-reference")],
         "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
         "post_training": [BuildProvider(provider_type="inline::huggingface")],
@@ -167,6 +165,11 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.SAFETY_MODEL:+llama-guard}",
             provider_shield_id="${env.SAFETY_MODEL:=}",
         ),
+        ShieldInput(
+            shield_id="code-scanner",
+            provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
+            provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
+        ),
     ]
 
     return DistributionTemplate(
diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
index be05ee436..6e05d5b83 100644
--- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
+++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
@@ -5,7 +5,11 @@
 # the root directory of this source tree.
 
 import logging
-from typing import Any
+import uuid
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from codeshield.cs import CodeShieldScanResult
 
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import (
@@ -14,6 +18,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -24,8 +29,8 @@ from .config import CodeScannerConfig
 log = logging.getLogger(__name__)
 
 ALLOWED_CODE_SCANNER_MODEL_IDS = [
-    "CodeScanner",
-    "CodeShield",
+    "code-scanner",
+    "code-shield",
 ]
 
 
@@ -69,3 +74,55 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
                 metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
             )
         return RunShieldResponse(violation=violation)
+
+    def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
+        categories = {}
+        category_scores = {}
+        category_applied_input_types = {}
+
+        flagged = scan_result.is_insecure
+        user_message = None
+        metadata = {}
+
+        if scan_result.is_insecure:
+            pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
+            categories = dict.fromkeys(pattern_ids, True)
+            category_scores = dict.fromkeys(pattern_ids, 1.0)
+            category_applied_input_types = {key: ["text"] for key in pattern_ids}
+            user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
+            metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
+
+        return ModerationObjectResults(
+            flagged=flagged,
+            categories=categories,
+            category_scores=category_scores,
+            category_applied_input_types=category_applied_input_types,
+            user_message=user_message,
+            metadata=metadata,
+        )
+
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+        inputs = input if isinstance(input, list) else [input]
+        results = []
+
+        from codeshield.cs import CodeShield
+
+        for text_input in inputs:
+            log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
+            try:
+                scan_result = await CodeShield.scan_code(text_input)
+                moderation_result = self.get_moderation_object_results(scan_result)
+            except Exception as e:
+                log.error(f"CodeShield.scan_code failed: {e}")
+                # create safe fallback response on scanner failure to avoid blocking legitimate requests
+                moderation_result = ModerationObjectResults(
+                    flagged=False,
+                    categories={},
+                    category_scores={},
+                    category_applied_input_types={},
+                    user_message=None,
+                    metadata={"scanner_error": str(e)},
+                )
+            results.append(moderation_result)
+
+        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index bae744010..5d52c5d89 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -11,11 +11,7 @@ from string import Template
 from typing import Any
 
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.apis.inference import (
-    Inference,
-    Message,
-    UserMessage,
-)
+from llama_stack.apis.inference import Inference, Message, UserMessage
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -72,7 +68,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
 }
 SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
 
-
 DEFAULT_LG_V3_SAFETY_CATEGORIES = [
     CAT_VIOLENT_CRIMES,
     CAT_NON_VIOLENT_CRIMES,
@@ -460,7 +455,7 @@ class LlamaGuardShield:
 
     def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
         """Check if content is safe based on response and unsafe code."""
-        if response.strip() == SAFE_RESPONSE:
+        if response.strip().lower().startswith(SAFE_RESPONSE):
             return True
 
         if unsafe_code:
diff --git a/tests/integration/safety/test_safety.py b/tests/integration/safety/test_safety.py
index 75b974926..a21d9cbdf 100644
--- a/tests/integration/safety/test_safety.py
+++ b/tests/integration/safety/test_safety.py
@@ -27,10 +27,10 @@ def data_url_from_image(file_path):
 
 @pytest.fixture(scope="session")
 def code_scanner_shield_id(available_shields):
-    if "CodeScanner" in available_shields:
-        return "CodeScanner"
+    if "code-scanner" in available_shields:
+        return "code-scanner"
 
-    pytest.skip("CodeScanner shield is not available. Skipping.")
+    pytest.skip("code-scanner shield is not available. Skipping.")
 
 
 def test_unsafe_examples(client_with_models, shield_id):
@@ -163,6 +163,60 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, mo
     assert response.violation.user_message == "Sorry, I found security concerns in the code."
 
 
+def test_unsafe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
+    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
+    if not code_safety_enabled:
+        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
+
+    insecure_code = [
+        """
+        Here is an example of code that hashes a password using MD5:
+
+        ```
+        import hashlib
+        password = 'securepassword123'.encode()
+        hashed_password = hashlib.md5(password).hexdigest()
+        print("Insecure hashed password:", hashed_password)
+        ```
+        """
+    ]
+    moderation_object = client_with_models.moderations.create(
+        input=insecure_code,
+        model=code_scanner_shield_id,
+    )
+    assert moderation_object.results[0].flagged is True, f"Code scanner should have flagged {insecure_code} as insecure"
+    assert all(value is True for value in moderation_object.results[0].categories.values()), (
+        "Code scanner shield should have detected code insecure category"
+    )
+
+
+def test_safe_input_with_code_scanner_run_moderation(client_with_models, code_scanner_shield_id, model_providers):
+    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
+    if not code_safety_enabled:
+        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
+
+    secure_code = [
+        """
+        Extract the first 5 characters from a string:
+        ```
+        text = "Hello World"
+        first_five = text[:5]
+        print(first_five)  # Output: "Hello"
+
+        # Safe handling for strings shorter than 5 characters
+        def get_first_five(text):
+            return text[:5] if text else ""
+        ```
+        """
+    ]
+    moderation_object = client_with_models.moderations.create(
+        input=secure_code,
+        model=code_scanner_shield_id,
+    )
+
+    assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure"
+
+
 # We can use an instance of the LlamaGuard shield to detect attempts to misuse
 # the interpreter as this is one of the existing categories it checks for
 def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):