feat: Moderation api for Code Scanner Provider

This commit is contained in:
Swapna Lekkala 2025-08-11 13:45:36 -07:00
parent 61582f327c
commit ef15c74307
3 changed files with 99 additions and 6 deletions

View file

@ -130,8 +130,8 @@ providers:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: CodeScanner
provider_type: inline::llama-guard provider_type: inline::code-scanner
config: config:
excluded_categories: [] excluded_categories: []
agents: agents:
@ -212,8 +212,8 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
models: [] models: []
shields: shields:
- shield_id: llama-guard - shield_id: CodeScanner
provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_id: ${env.SAFETY_MODEL:+CodeScanner}
provider_shield_id: ${env.SAFETY_MODEL:=} provider_shield_id: ${env.SAFETY_MODEL:=}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []

View file

@ -5,8 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
import uuid
from typing import Any from typing import Any
from codeshield.cs import CodeShield, CodeShieldScanResult
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import ( from llama_stack.apis.safety import (
RunShieldResponse, RunShieldResponse,
@ -14,6 +17,7 @@ from llama_stack.apis.safety import (
SafetyViolation, SafetyViolation,
ViolationLevel, ViolationLevel,
) )
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
@ -55,8 +59,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
if not shield: if not shield:
raise ValueError(f"Shield {shield_id} not found") raise ValueError(f"Shield {shield_id} not found")
from codeshield.cs import CodeShield
text = "\n".join([interleaved_content_as_str(m.content) for m in messages]) text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
log.info(f"Running CodeScannerShield on {text[50:]}") log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text) result = await CodeShield.scan_code(text)
@ -69,3 +71,42 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])}, metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
) )
return RunShieldResponse(violation=violation) return RunShieldResponse(violation=violation)
def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults:
categories = {}
category_scores = {}
category_applied_input_types = {}
flagged = scan_result.is_insecure
user_message = None
metadata = {}
# TODO check both list of inputs and single input; add some unit tests
if scan_result.is_insecure:
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
categories = dict.fromkeys(pattern_ids, True)
category_scores = dict.fromkeys(pattern_ids, 1.0)
category_applied_input_types = {key: ["text"] for key in pattern_ids}
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
return ModerationObjectResults(
flagged=flagged,
categories=categories,
category_scores=category_scores,
category_applied_input_types=category_applied_input_types,
user_message=user_message,
metadata=metadata,
)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
inputs = input if isinstance(input, list) else [input]
results = []
for text_input in inputs:
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
scan_result = await CodeShield.scan_code(text_input)
moderation_result = self.get_moderation_object_results(scan_result)
results.append(moderation_result)
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)

View file

@ -163,6 +163,58 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, mo
assert response.violation.user_message == "Sorry, I found security concerns in the code." assert response.violation.user_message == "Sorry, I found security concerns in the code."
# The code scanning shield uses Meta's Code Shield library to detect violations
# in which an LLM generates insecure code. Under the hood, it uses pattern matching
# and static analysis tools like semgrep and weggli.
def test_unsafe_input_with_code_scanner_moderations(client_with_models, code_scanner_shield_id, model_providers):
code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
if not code_safety_enabled:
pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
insecure_code = """
Here is an example of code that hashes a password using MD5:
```
import hashlib
password = 'securepassword123'.encode()
hashed_password = hashlib.md5(password).hexdigest()
print("Insecure hashed password:", hashed_password)
```
"""
moderation_object = client_with_models.moderations.create(
input=insecure_code,
model=code_scanner_shield_id,
)
assert moderation_object.results[0].flagged is True, "Code scanner should have flagged the code as insecure"
assert all(value is True for value in moderation_object.results[0].categories.values()), (
"Code scanner shield should have detected code insecure category"
)
def test_safe_input_with_code_scanner_moderations_api(client_with_models, code_scanner_shield_id, model_providers):
code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
if not code_safety_enabled:
pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
secure_code = """
Extract the first 5 characters from a string:
```
text = "Hello World"
first_five = text[:5]
print(first_five) # Output: "Hello"
# Safe handling for strings shorter than 5 characters
def get_first_five(text):
return text[:5] if text else ""
```
"""
moderation_object = client_with_models.moderations.create(
input=secure_code,
model=code_scanner_shield_id,
)
assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure"
# We can use an instance of the LlamaGuard shield to detect attempts to misuse # We can use an instance of the LlamaGuard shield to detect attempts to misuse
# the interpreter as this is one of the existing categories it checks for # the interpreter as this is one of the existing categories it checks for
def test_safety_with_code_interpreter_abuse(client_with_models, shield_id): def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):