mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
feat: Moderation api for Code Scanner Provider
This commit is contained in:
parent
61582f327c
commit
ef15c74307
3 changed files with 99 additions and 6 deletions
|
@ -5,8 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from codeshield.cs import CodeShield, CodeShieldScanResult
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
|
@ -14,6 +17,7 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
@ -55,8 +59,6 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
||||
log.info(f"Running CodeScannerShield on {text[50:]}")
|
||||
result = await CodeShield.scan_code(text)
|
||||
|
@ -69,3 +71,42 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults:
|
||||
categories = {}
|
||||
category_scores = {}
|
||||
category_applied_input_types = {}
|
||||
|
||||
flagged = scan_result.is_insecure
|
||||
user_message = None
|
||||
metadata = {}
|
||||
# TODO check both list of inputs and single input; add some unit tests
|
||||
|
||||
if scan_result.is_insecure:
|
||||
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
|
||||
categories = dict.fromkeys(pattern_ids, True)
|
||||
category_scores = dict.fromkeys(pattern_ids, 1.0)
|
||||
category_applied_input_types = {key: ["text"] for key in pattern_ids}
|
||||
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
|
||||
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
|
||||
|
||||
return ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_scores=category_scores,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
inputs = input if isinstance(input, list) else [input]
|
||||
results = []
|
||||
|
||||
for text_input in inputs:
|
||||
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
|
||||
scan_result = await CodeShield.scan_code(text_input)
|
||||
moderation_result = self.get_moderation_object_results(scan_result)
|
||||
results.append(moderation_result)
|
||||
|
||||
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue