mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
remove unwanted
This commit is contained in:
parent
0b429da496
commit
900424ef61
2 changed files with 9 additions and 4 deletions
|
@ -6,9 +6,10 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from codeshield.cs import CodeShield, CodeShieldScanResult
|
if TYPE_CHECKING:
|
||||||
|
from codeshield.cs import CodeShieldScanResult
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.safety import (
|
from llama_stack.apis.safety import (
|
||||||
|
@ -59,6 +60,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
if not shield:
|
if not shield:
|
||||||
raise ValueError(f"Shield {shield_id} not found")
|
raise ValueError(f"Shield {shield_id} not found")
|
||||||
|
|
||||||
|
from codeshield.cs import CodeShield
|
||||||
|
|
||||||
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
||||||
log.info(f"Running CodeScannerShield on {text[50:]}")
|
log.info(f"Running CodeScannerShield on {text[50:]}")
|
||||||
result = await CodeShield.scan_code(text)
|
result = await CodeShield.scan_code(text)
|
||||||
|
@ -72,7 +75,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
)
|
)
|
||||||
return RunShieldResponse(violation=violation)
|
return RunShieldResponse(violation=violation)
|
||||||
|
|
||||||
def get_moderation_object_results(self, scan_result: CodeShieldScanResult) -> ModerationObjectResults:
|
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
|
||||||
categories = {}
|
categories = {}
|
||||||
category_scores = {}
|
category_scores = {}
|
||||||
category_applied_input_types = {}
|
category_applied_input_types = {}
|
||||||
|
@ -102,6 +105,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
inputs = input if isinstance(input, list) else [input]
|
inputs = input if isinstance(input, list) else [input]
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
from codeshield.cs import CodeShield
|
||||||
|
|
||||||
for text_input in inputs:
|
for text_input in inputs:
|
||||||
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
|
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
|
||||||
scan_result = await CodeShield.scan_code(text_input)
|
scan_result = await CodeShield.scan_code(text_input)
|
||||||
|
|
|
@ -455,7 +455,7 @@ class LlamaGuardShield:
|
||||||
|
|
||||||
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
|
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
|
||||||
"""Check if content is safe based on response and unsafe code."""
|
"""Check if content is safe based on response and unsafe code."""
|
||||||
if response.strip() == SAFE_RESPONSE:
|
if response.strip().lower().startswith(SAFE_RESPONSE):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if unsafe_code:
|
if unsafe_code:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue