feat: Moderation api for Code Scanner Provider

This commit is contained in:
Swapna Lekkala 2025-08-11 13:45:36 -07:00
parent 61582f327c
commit ef15c74307
3 changed files with 99 additions and 6 deletions

View file

@ -5,8 +5,11 @@
# the root directory of this source tree.
import logging
import uuid
from typing import Any
from codeshield.cs import CodeShield, CodeShieldScanResult
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
@ -14,6 +17,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
@ -55,8 +59,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
if not shield:
raise ValueError(f"Shield {shield_id} not found")
from codeshield.cs import CodeShield
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text)
@ -69,3 +71,42 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
)
return RunShieldResponse(violation=violation)
def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults:
categories = {}
category_scores = {}
category_applied_input_types = {}
flagged = scan_result.is_insecure
user_message = None
metadata = {}
# TODO check both list of inputs and single input; add some unit tests
if scan_result.is_insecure:
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
categories = dict.fromkeys(pattern_ids, True)
category_scores = dict.fromkeys(pattern_ids, 1.0)
category_applied_input_types = {key: ["text"] for key in pattern_ids}
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
return ModerationObjectResults(
flagged=flagged,
categories=categories,
category_scores=category_scores,
category_applied_input_types=category_applied_input_types,
user_message=user_message,
metadata=metadata,
)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
inputs = input if isinstance(input, list) else [input]
results = []
for text_input in inputs:
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
scan_result = await CodeShield.scan_code(text_input)
moderation_result = self.get_moderation_object_results(scan_result)
results.append(moderation_result)
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)