feat: Moderation api for Code Scanner Provider
Parent commit: 61582f327c
This commit: ef15c74307
3 changed files with 99 additions and 6 deletions
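The new moderation path can be exercised through the moderations endpoint once a code-scanner shield is registered (see the config change below). A minimal sketch, assuming a running stack, the Python client package, and the default local port; the snippet and literals are illustrative and not part of this commit:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local deployment
    response = client.moderations.create(
        input="import hashlib; hashlib.md5(b'pw').hexdigest()",  # sample insecure snippet
        model="CodeScanner",
    )
    result = response.results[0]
    # flagged is True when Code Shield finds an issue; categories are keyed by pattern id
    print(result.flagged, result.categories, result.user_message)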
@@ -130,8 +130,8 @@ providers:
         type: sqlite
         db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
   safety:
-  - provider_id: llama-guard
-    provider_type: inline::llama-guard
+  - provider_id: CodeScanner
+    provider_type: inline::code-scanner
     config:
       excluded_categories: []
   agents:
@@ -212,8 +212,8 @@ inference_store:
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
 models: []
 shields:
-- shield_id: llama-guard
-  provider_id: ${env.SAFETY_MODEL:+llama-guard}
+- shield_id: CodeScanner
+  provider_id: ${env.SAFETY_MODEL:+CodeScanner}
   provider_shield_id: ${env.SAFETY_MODEL:=}
 vector_dbs: []
 datasets: []
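With the shield entry above pointing at the code-scanner provider, a CodeScanner shield should be registered when the stack starts. A quick sanity check from the same client as in the earlier example (a sketch; attribute names follow the Shield resource):

    # Sketch: confirm the CodeScanner shield is registered.
    shields = client.shields.list()
    assert any(s.identifier == "CodeScanner" for s in shields), "CodeScanner shield not registered"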
@@ -5,8 +5,11 @@
 # the root directory of this source tree.
 
 import logging
+import uuid
 from typing import Any
 
+from codeshield.cs import CodeShield, CodeShieldScanResult
+
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import (
     RunShieldResponse,
@@ -14,6 +17,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -55,8 +59,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
         if not shield:
             raise ValueError(f"Shield {shield_id} not found")
 
-        from codeshield.cs import CodeShield
-
         text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
         log.info(f"Running CodeScannerShield on {text[50:]}")
         result = await CodeShield.scan_code(text)
@@ -69,3 +71,42 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
                 metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
             )
         return RunShieldResponse(violation=violation)
+
+    def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults:
+        categories = {}
+        category_scores = {}
+        category_applied_input_types = {}
+
+        flagged = scan_result.is_insecure
+        user_message = None
+        metadata = {}
+        # TODO check both list of inputs and single input; add some unit tests
+
+        if scan_result.is_insecure:
+            pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
+            categories = dict.fromkeys(pattern_ids, True)
+            category_scores = dict.fromkeys(pattern_ids, 1.0)
+            category_applied_input_types = {key: ["text"] for key in pattern_ids}
+            user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
+            metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
+
+        return ModerationObjectResults(
+            flagged=flagged,
+            categories=categories,
+            category_scores=category_scores,
+            category_applied_input_types=category_applied_input_types,
+            user_message=user_message,
+            metadata=metadata,
+        )
+
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+        inputs = input if isinstance(input, list) else [input]
+        results = []
+
+        for text_input in inputs:
+            log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
+            scan_result = await CodeShield.scan_code(text_input)
+            moderation_result = self.get_moderation_object_results(scan_result)
+            results.append(moderation_result)
+
+        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
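run_moderation above accepts either a single string or a list of strings and produces one ModerationObjectResults entry per input, built by get_moderation_object_results. A rough sketch of how the two methods compose, assuming an already-initialized provider instance impl (construction and shield registration omitted):

    import asyncio

    async def demo(impl):
        # Both a single string and a list are accepted; a string is wrapped into a one-item list.
        moderation = await impl.run_moderation(
            input=[
                "print('hello world')",                    # expected to pass
                "hashlib.md5(b'password').hexdigest()",    # expected to be flagged
            ],
            model="CodeScanner",
        )
        for result in moderation.results:
            # When flagged, categories and category_scores are keyed by Code Shield pattern ids.
            print(result.flagged, result.categories, result.user_message)

    # asyncio.run(demo(impl))  # impl: a MetaReferenceCodeScannerSafetyImpl instance (assumed)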
@@ -163,6 +163,58 @@ def test_safety_with_code_scanner(client_with_models, code_scanner_shield_id, model_providers):
     assert response.violation.user_message == "Sorry, I found security concerns in the code."
 
 
+# The code scanning shield uses Meta's Code Shield library to detect violations
+# in which an LLM generates insecure code. Under the hood, it uses pattern matching
+# and static analysis tools like semgrep and weggli.
+def test_unsafe_input_with_code_scanner_moderations(client_with_models, code_scanner_shield_id, model_providers):
+    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
+    if not code_safety_enabled:
+        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
+
+    insecure_code = """
+    Here is an example of code that hashes a password using MD5:
+
+    ```
+    import hashlib
+    password = 'securepassword123'.encode()
+    hashed_password = hashlib.md5(password).hexdigest()
+    print("Insecure hashed password:", hashed_password)
+    ```
+    """
+    moderation_object = client_with_models.moderations.create(
+        input=insecure_code,
+        model=code_scanner_shield_id,
+    )
+    assert moderation_object.results[0].flagged is True, "Code scanner should have flagged the code as insecure"
+    assert all(value is True for value in moderation_object.results[0].categories.values()), (
+        "Code scanner shield should have detected code insecure category"
+    )
+
+
+def test_safe_input_with_code_scanner_moderations_api(client_with_models, code_scanner_shield_id, model_providers):
+    code_safety_enabled = len(CODE_SCANNER_ENABLED_PROVIDERS & model_providers) > 0
+    if not code_safety_enabled:
+        pytest.skip(f"Testing code scanner shields is not supported for model_providers {model_providers}")
+
+    secure_code = """
+    Extract the first 5 characters from a string:
+    ```
+    text = "Hello World"
+    first_five = text[:5]
+    print(first_five)  # Output: "Hello"
+
+    # Safe handling for strings shorter than 5 characters
+    def get_first_five(text):
+        return text[:5] if text else ""
+    ```
+    """
+    moderation_object = client_with_models.moderations.create(
+        input=secure_code,
+        model=code_scanner_shield_id,
+    )
+    assert moderation_object.results[0].flagged is False, "Code scanner should not have flagged the code as insecure"
+
+
 # We can use an instance of the LlamaGuard shield to detect attempts to misuse
 # the interpreter as this is one of the existing categories it checks for
 def test_safety_with_code_interpreter_abuse(client_with_models, shield_id):