Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-21 09:23:13 +00:00)
feat: Code scanner Provider impl for moderations api (#3100)
# What does this PR do?

Add CodeScanner implementations.

## Test Plan

`SAFETY_MODEL=CodeScanner LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama`

This PR needs to land after https://github.com/meta-llama/llama-stack/pull/3098.
This commit is contained in: parent 27d6becfd0, commit 7519ab4024
9 changed files with 144 additions and 24 deletions
```diff
@@ -5,7 +5,11 @@
 # the root directory of this source tree.
 
 import logging
-from typing import Any
+import uuid
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from codeshield.cs import CodeShieldScanResult
 
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import (
@@ -14,6 +18,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
+from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -24,8 +29,8 @@ from .config import CodeScannerConfig
 log = logging.getLogger(__name__)
 
 ALLOWED_CODE_SCANNER_MODEL_IDS = [
-    "CodeScanner",
-    "CodeShield",
+    "code-scanner",
+    "code-shield",
 ]
 
 
```
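The renamed model IDs above feed a registration-time check elsewhere in the provider. As a purely hypothetical sketch (the actual register_shield logic is not shown in this diff and may differ), such a check could look like:

```python
# Hypothetical illustration only: the real register_shield logic lives outside
# this hunk. It shows the kind of validation the renamed IDs feed into.
ALLOWED_CODE_SCANNER_MODEL_IDS = [
    "code-scanner",
    "code-shield",
]


def validate_shield_model_id(provider_resource_id: str) -> None:
    """Reject shields whose model ID is not a known code-scanner alias."""
    if provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
        raise ValueError(
            f"Unsupported Code Scanner model ID: {provider_resource_id!r}. "
            f"Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
        )


validate_shield_model_id("code-shield")  # passes silently
```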
@ -69,3 +74,55 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
|
||||
categories = {}
|
||||
category_scores = {}
|
||||
category_applied_input_types = {}
|
||||
|
||||
flagged = scan_result.is_insecure
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
||||
if scan_result.is_insecure:
|
||||
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
|
||||
categories = dict.fromkeys(pattern_ids, True)
|
||||
category_scores = dict.fromkeys(pattern_ids, 1.0)
|
||||
category_applied_input_types = {key: ["text"] for key in pattern_ids}
|
||||
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
|
||||
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
|
||||
|
||||
return ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_scores=category_scores,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
inputs = input if isinstance(input, list) else [input]
|
||||
results = []
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
for text_input in inputs:
|
||||
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
|
||||
try:
|
||||
scan_result = await CodeShield.scan_code(text_input)
|
||||
moderation_result = self.get_moderation_object_results(scan_result)
|
||||
except Exception as e:
|
||||
log.error(f"CodeShield.scan_code failed: {e}")
|
||||
# create safe fallback response on scanner failure to avoid blocking legitimate requests
|
||||
moderation_result = ModerationObjectResults(
|
||||
flagged=False,
|
||||
categories={},
|
||||
category_scores={},
|
||||
category_applied_input_types={},
|
||||
user_message=None,
|
||||
metadata={"scanner_error": str(e)},
|
||||
)
|
||||
results.append(moderation_result)
|
||||
|
||||
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
|
||||
|
|
|
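Before the LlamaGuard provider changes below, here is a minimal, self-contained sketch of what the new get_moderation_object_results() mapping produces for an insecure scan. The FakeIssue/FakeScanResult stand-ins are assumptions that mimic only the attributes the method reads (is_insecure, recommended_treatment, issues_found); the real objects come from codeshield.cs.

```python
# Standalone sketch (not part of the PR): stand-ins for codeshield's scan result,
# used to show the category mapping performed by get_moderation_object_results().
from dataclasses import dataclass, field
from enum import Enum


class Treatment(Enum):
    WARN = "warn"
    BLOCK = "block"


@dataclass
class FakeIssue:
    pattern_id: str
    description: str


@dataclass
class FakeScanResult:
    is_insecure: bool
    recommended_treatment: Treatment
    issues_found: list[FakeIssue] = field(default_factory=list)


scan = FakeScanResult(
    is_insecure=True,
    recommended_treatment=Treatment.BLOCK,
    issues_found=[FakeIssue("weak-md5-hashing", "Use of weak hashing algorithm MD5")],
)

# Same mapping as in get_moderation_object_results(): each pattern_id becomes a
# flagged category with score 1.0, applied to text input.
pattern_ids = [issue.pattern_id for issue in scan.issues_found]
print(dict.fromkeys(pattern_ids, True))        # {'weak-md5-hashing': True}
print(dict.fromkeys(pattern_ids, 1.0))         # {'weak-md5-hashing': 1.0}
print({key: ["text"] for key in pattern_ids})  # {'weak-md5-hashing': ['text']}
print(
    f"Security concerns detected in the code. {scan.recommended_treatment.name}: "
    + ", ".join(issue.description for issue in scan.issues_found)
)
```

On a clean scan (is_insecure=False) the same method returns an unflagged ModerationObjectResults with empty category maps, mirroring the fallback object built in the except branch of run_moderation.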
@ -11,11 +11,7 @@ from string import Template
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
Message,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference, Message, UserMessage
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
@ -72,7 +68,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
|
|||
}
|
||||
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
|
||||
|
||||
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||
CAT_VIOLENT_CRIMES,
|
||||
CAT_NON_VIOLENT_CRIMES,
|
||||
|
@ -460,7 +455,7 @@ class LlamaGuardShield:
|
|||
|
||||
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
|
||||
"""Check if content is safe based on response and unsafe code."""
|
||||
if response.strip() == SAFE_RESPONSE:
|
||||
if response.strip().lower().startswith(SAFE_RESPONSE):
|
||||
return True
|
||||
|
||||
if unsafe_code:
|
||||
|
|
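To make the LlamaGuardShield.is_content_safe change concrete, here is a small comparison of the old strict-equality check against the new case-insensitive prefix check (assuming SAFE_RESPONSE is the lowercase string "safe"; its definition is outside this hunk):

```python
# Illustration of the relaxed check; SAFE_RESPONSE == "safe" is an assumption,
# since its definition is not shown in this diff.
SAFE_RESPONSE = "safe"


def old_is_safe(response: str) -> bool:
    return response.strip() == SAFE_RESPONSE


def new_is_safe(response: str) -> bool:
    return response.strip().lower().startswith(SAFE_RESPONSE)


for response in ["safe", "Safe", " safe\n", "Safe.", "unsafe\nS8"]:
    print(repr(response), old_is_safe(response), new_is_safe(response))
# "Safe", "Safe." and other case/punctuation variants now count as safe,
# while "unsafe..." responses are still flagged by both checks.
```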