diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 46bd12956..493977ded 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -130,8 +130,8 @@ providers: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_id: CodeScanner + provider_type: inline::code-scanner config: excluded_categories: [] agents: @@ -212,8 +212,8 @@ inference_store: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db models: [] shields: -- shield_id: llama-guard - provider_id: ${env.SAFETY_MODEL:+llama-guard} +- shield_id: CodeScanner + provider_id: ${env.SAFETY_MODEL:+CodeScanner} provider_shield_id: ${env.SAFETY_MODEL:=} vector_dbs: [] datasets: [] diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index be05ee436..6e1e10b18 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -5,8 +5,11 @@ # the root directory of this source tree. import logging +import uuid from typing import Any +from codeshield.cs import CodeShield, CodeShieldScanResult + from llama_stack.apis.inference import Message from llama_stack.apis.safety import ( RunShieldResponse, @@ -14,6 +17,7 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) +from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -55,8 +59,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): if not shield: raise ValueError(f"Shield {shield_id} not found") - from codeshield.cs import CodeShield - text = "\n".join([interleaved_content_as_str(m.content) for m in messages]) log.info(f"Running CodeScannerShield on {text[50:]}") result = await CodeShield.scan_code(text) @@ -69,3 +71,42 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])}, ) return RunShieldResponse(violation=violation) + + def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults: + categories = {} + category_scores = {} + category_applied_input_types = {} + + flagged = scan_result.is_insecure + user_message = None + metadata = {} + # TODO check both list of inputs and single input; add some unit tests + + if scan_result.is_insecure: + pattern_ids = [issue.pattern_id for issue in scan_result.issues_found] + categories = dict.fromkeys(pattern_ids, True) + category_scores = dict.fromkeys(pattern_ids, 1.0) + category_applied_input_types = {key: ["text"] for key in pattern_ids} + user_message = f"Security concerns detected in the code. 
{scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}" + metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])} + + return ModerationObjectResults( + flagged=flagged, + categories=categories, + category_scores=category_scores, + category_applied_input_types=category_applied_input_types, + user_message=user_message, + metadata=metadata, + ) + + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + inputs = input if isinstance(input, list) else [input] + results = [] + + for text_input in inputs: + log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...") + scan_result = await CodeShield.scan_code(text_input) + moderation_result = self.get_moderation_object_results(scan_result) + results.append(moderation_result) + + return ModerationObject(id=str(uuid.uuid4()), model=model, results=results) diff --git a/tests/integration/safety/test_safety.py b/tests/integration/safety/test_safety.py index 75b974926..7af0a2827 100644 --- a/tests/integration/safety/test_safety.py +++ b/tests/integration/safety/test_safety.py @@ -163,6 +163,58 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, mo assert response.violation.user_message == "Sorry, I found security concerns in the code." +# The code scanning shield uses Meta's Code Shield library to detect violations +# in which an LLM generates insecure code. Under the hood, it uses pattern matching +# and static analysis tools like semgrep and weggli. +def test_unsafe_input_with_code_scanner_moderations(client_with_models, code_scanner_shield_id, model_providers): + code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0 + if not code_safety_enabled: + pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}") + + insecure_code = """ + Here is an example of code that hashes a password using MD5: + + ``` + import hashlib + password = 'securepassword123'.encode() + hashed_password = hashlib.md5(password).hexdigest() + print("Insecure hashed password:", hashed_password) + ``` + """ + moderation_object = client_with_models.moderations.create( + input=insecure_code, + model=code_scanner_shield_id, + ) + assert moderation_object.results[0].flagged is True, "Code scanner should have flagged the code as insecure" + assert all(value is True for value in moderation_object.results[0].categories.values()), ( + "Code scanner shield should have detected code insecure category" + ) + + +def test_safe_input_with_code_scanner_moderations_api(client_with_models, code_scanner_shield_id, model_providers): + code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0 + if not code_safety_enabled: + pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}") + + secure_code = """ + Extract the first 5 characters from a string: + ``` + text = "Hello World" + first_five = text[:5] + print(first_five) # Output: "Hello" + + # Safe handling for strings shorter than 5 characters + def get_first_five(text): + return text[:5] if text else "" + ``` + """ + moderation_object = client_with_models.moderations.create( + input=secure_code, + model=code_scanner_shield_id, + ) + assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure" + + # We can use an instance of the LlamaGuard shield to detect attempts to misuse # the 
# interpreter as this is one of the existing categories it checks for
def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):
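
A minimal usage sketch for reviewers (not part of the diff): calling the new moderations flow backed by `inline::code-scanner`. It assumes the `llama_stack_client` package, a starter distribution running on the default local port, and the `CodeScanner` shield id from the run.yaml change above; the endpoint URL and the sample snippet are illustrative.

```python
# Illustrative sketch, not part of this PR: exercise the new run_moderation path
# through the same client API the integration tests above use.
from llama_stack_client import LlamaStackClient

# Assumed local endpoint for a running starter distribution.
client = LlamaStackClient(base_url="http://localhost:8321")

insecure_code = """
import hashlib
password = 'securepassword123'.encode()
hashed_password = hashlib.md5(password).hexdigest()
"""

moderation = client.moderations.create(input=insecure_code, model="CodeScanner")

result = moderation.results[0]
print(result.flagged)      # True when Code Shield flags the input
print(result.categories)   # {"<pattern_id>": True, ...} for each issue found
print(result.metadata)     # {"violation_type": "<comma-separated pattern ids>"}
```

The same call accepts a list of strings as `input`; `run_moderation` scans each entry and returns one entry in `results` per input.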
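
The TODO in `get_moderation_object_results` calls for unit tests; below is a hedged sketch of one, not included in this PR. It stubs the scan result instead of constructing a real `CodeShieldScanResult` (whose constructor is not shown here), so the stub carries only the attributes the helper reads, and the pattern id is made up for illustration.

```python
# Sketch of a possible unit test for the new mapping helper (not part of this PR).
# Assumes the codeshield dependency is importable, since the module now imports it
# at the top level.
from types import SimpleNamespace

from llama_stack.providers.inline.safety.code_scanner.code_scanner import (
    MetaReferenceCodeScannerSafetyImpl,
)


def test_get_moderation_object_results_flags_insecure_scan():
    # Stub mirroring only the fields the helper reads; "weak-md5" is an
    # illustrative pattern id, not a real Code Shield rule name.
    issue = SimpleNamespace(pattern_id="weak-md5", description="Use of weak MD5 hashing")
    scan_result = SimpleNamespace(
        is_insecure=True,
        issues_found=[issue],
        recommended_treatment=SimpleNamespace(name="WARN"),
    )

    # The helper does not use `self`, so it can be invoked unbound for a quick check.
    result = MetaReferenceCodeScannerSafetyImpl.get_moderation_object_results(None, scan_result)

    assert result.flagged is True
    assert result.categories == {"weak-md5": True}
    assert result.category_scores == {"weak-md5": 1.0}
    assert result.category_applied_input_types == {"weak-md5": ["text"]}
    assert result.metadata == {"violation_type": "weak-md5"}
```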