feat: Code scanner Provider impl for moderations api (#3100)

# What does this PR do?
Adds a CodeScanner provider implementation for the moderations API.

## Test Plan
`SAFETY_MODEL=CodeScanner LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`

This PR needs to land after https://github.com/meta-llama/llama-stack/pull/3098.
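For reference, here is a minimal sketch of the underlying scan that the new moderation path wraps. It assumes the `codeshield` package is installed and uses only calls that appear in the diff below (`CodeShield.scan_code`, `is_insecure`, `issues_found`); the sample snippet and printed fields are illustrative, not part of this change:

```python
# Sketch only: exercises the CodeShield scan that run_moderation wraps.
import asyncio

from codeshield.cs import CodeShield


async def main() -> None:
    # A deliberately weak-hash snippet so the scanner has something to flag.
    code = 'import hashlib\nhashlib.md5(b"password")'
    scan_result = await CodeShield.scan_code(code)
    print("insecure:", scan_result.is_insecure)
    if scan_result.is_insecure:
        for issue in scan_result.issues_found:
            print(issue.pattern_id, "-", issue.description)


asyncio.run(main())
```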
Commit 7519ab4024 (parent 27d6becfd0), authored by slekkala1 on 2025-08-18 14:15:40 -07:00 and committed by GitHub.
9 changed files with 144 additions and 24 deletions


@@ -5,7 +5,11 @@
# the root directory of this source tree.
import logging
from typing import Any
import uuid
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
    from codeshield.cs import CodeShieldScanResult

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
@@ -14,6 +18,7 @@ from llama_stack.apis.safety import (
    SafetyViolation,
    ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
@@ -24,8 +29,8 @@ from .config import CodeScannerConfig
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [
    "CodeScanner",
    "CodeShield",
    "code-scanner",
    "code-shield",
]
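The hunk above replaces the CamelCase shield IDs with lowercase ones (`code-scanner`, `code-shield`). As a hypothetical illustration of the kind of membership check such an allow-list typically backs (the validation body below is an assumption, not code from this diff):

```python
# Hypothetical sketch -- not the provider's actual registration code.
ALLOWED_CODE_SCANNER_MODEL_IDS = [
    "code-scanner",
    "code-shield",
]


def validate_shield_id(provider_resource_id: str) -> None:
    # Reject any shield whose provider resource ID is not in the allow-list.
    if provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
        raise ValueError(
            f"Unsupported Code Scanner ID: {provider_resource_id}. "
            f"Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
        )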
@@ -69,3 +74,55 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
                metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
            )
        return RunShieldResponse(violation=violation)

    def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
        categories = {}
        category_scores = {}
        category_applied_input_types = {}

        flagged = scan_result.is_insecure
        user_message = None
        metadata = {}

        if scan_result.is_insecure:
            pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
            categories = dict.fromkeys(pattern_ids, True)
            category_scores = dict.fromkeys(pattern_ids, 1.0)
            category_applied_input_types = {key: ["text"] for key in pattern_ids}
            user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
            metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}

        return ModerationObjectResults(
            flagged=flagged,
            categories=categories,
            category_scores=category_scores,
            category_applied_input_types=category_applied_input_types,
            user_message=user_message,
            metadata=metadata,
        )

    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        inputs = input if isinstance(input, list) else [input]
        results = []

        from codeshield.cs import CodeShield

        for text_input in inputs:
            log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
            try:
                scan_result = await CodeShield.scan_code(text_input)
                moderation_result = self.get_moderation_object_results(scan_result)
            except Exception as e:
                log.error(f"CodeShield.scan_code failed: {e}")
                # create safe fallback response on scanner failure to avoid blocking legitimate requests
                moderation_result = ModerationObjectResults(
                    flagged=False,
                    categories={},
                    category_scores={},
                    category_applied_input_types={},
                    user_message=None,
                    metadata={"scanner_error": str(e)},
                )
            results.append(moderation_result)

        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
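As a usage note, here is a hedged sketch of how a caller might consume `run_moderation`'s output. Constructing the provider (`impl` below) requires config and dependency wiring that is outside this diff, so it is left abstract; only the `run_moderation` signature and the `ModerationObject` fields it populates come from the change above:

```python
# Illustrative only: `impl` stands for a constructed MetaReferenceCodeScannerSafetyImpl.
async def code_is_safe(impl, code: str) -> bool:
    moderation = await impl.run_moderation(input=code, model="code-scanner")
    result = moderation.results[0]
    if result.flagged:
        # `categories` maps CodeShield pattern ids to True for flagged findings.
        print("blocked:", result.categories, result.user_message)
    return not result.flagged
```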