Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-15 01:26:10 +00:00

Commit b6ccaf1778 (parent 94dfa293a6): formatting

33 changed files with 110 additions and 97 deletions
```diff
@@ -41,7 +41,6 @@ def resolve_and_get_path(model_name: str) -> str:
 
 
 class MetaReferenceSafetyImpl(Safety):
-
     def __init__(self, config: SafetyConfig) -> None:
         self.config = config
 
```
```diff
@@ -14,7 +14,6 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 
 class ShieldBase(ABC):
-
     def __init__(
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
```
```diff
@@ -60,7 +59,6 @@ class TextShield(ShieldBase):
 
 
 class DummyShield(TextShield):
-
     def get_shield_type(self) -> ShieldType:
         return "dummy"
 
```
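For orientation, these hunks touch the shield class hierarchy: ShieldBase takes an on_violation_action, and concrete shields implement get_shield_type. A minimal self-contained sketch of that shape (the OnViolationAction members and the ShieldType alias are stand-ins, not copied from the repo; only the names and signatures come from the hunks above):

```python
# Minimal sketch of the shield interface visible in this diff.
from abc import ABC, abstractmethod
from enum import Enum, auto

ShieldType = str  # assumed alias; the repo defines its own type


class OnViolationAction(Enum):
    IGNORE = auto()  # invented member names
    RAISE = auto()   # RAISE appears as the default in the hunks above


class ShieldBase(ABC):
    def __init__(
        self,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ) -> None:
        self.on_violation_action = on_violation_action

    @abstractmethod
    def get_shield_type(self) -> ShieldType: ...


class TextShield(ShieldBase):
    """Base for shields that scan plain message text."""


class DummyShield(TextShield):
    def get_shield_type(self) -> ShieldType:
        return "dummy"


print(DummyShield().get_shield_type())  # dummy
```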
```diff
@@ -12,7 +12,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
 
 
 class CodeScannerShield(TextShield):
-
     def get_shield_type(self) -> ShieldType:
         return BuiltinShield.code_scanner_guard
 
```
```diff
@@ -100,7 +100,6 @@ PROMPT_TEMPLATE = Template(
 
 
 class LlamaGuardShield(ShieldBase):
-
     @staticmethod
     def instance(
         on_violation_action=OnViolationAction.RAISE,
```
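The @staticmethod instance() here is a factory entry point. A sketch of how such a factory might work, assuming (not shown in this diff) that it memoizes a single shared shield:

```python
# Sketch of the instance() factory pattern from the hunk above; the
# memoization and constructor arguments are assumptions.
from enum import Enum, auto
from typing import Optional


class OnViolationAction(Enum):
    IGNORE = auto()
    RAISE = auto()


class LlamaGuardShield:
    _instance: Optional["LlamaGuardShield"] = None

    def __init__(self, on_violation_action: OnViolationAction) -> None:
        self.on_violation_action = on_violation_action

    @staticmethod
    def instance(
        on_violation_action=OnViolationAction.RAISE,
    ) -> "LlamaGuardShield":
        # Lazily create one shared shield, a common reason for a static
        # instance() factory (assumed here, not shown in the diff).
        if LlamaGuardShield._instance is None:
            LlamaGuardShield._instance = LlamaGuardShield(on_violation_action)
        return LlamaGuardShield._instance


shield = LlamaGuardShield.instance()
print(shield.on_violation_action)
```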
```diff
@@ -166,7 +165,6 @@ class LlamaGuardShield(ShieldBase):
         return None
 
     def get_safety_categories(self) -> List[str]:
-
         excluded_categories = self.excluded_categories
         if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
             excluded_categories = []
```
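The reset in this hunk is worth calling out: excluding every known category code is treated the same as excluding none. A standalone sketch with invented category codes (only the reset-to-empty behavior mirrors the hunk above):

```python
# Standalone sketch of the exclusion reset in get_safety_categories.
from typing import Dict, List

SAFETY_CATEGORIES_TO_CODE_MAP: Dict[str, str] = {
    "Violence and Hate": "O1",   # hypothetical entries for illustration
    "Criminal Planning": "O2",
}


def get_safety_categories(excluded_categories: List[str]) -> List[str]:
    # Excluding every known code falls back to checking all categories.
    if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
        excluded_categories = []
    return [
        f"{code}: {name}"
        for name, code in SAFETY_CATEGORIES_TO_CODE_MAP.items()
        if code not in excluded_categories
    ]


print(get_safety_categories(["O1", "O2"]))
# ['O1: Violence and Hate', 'O2: Criminal Planning']
```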
```diff
@@ -181,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
         return categories
 
     def build_prompt(self, messages: List[Message]) -> str:
-
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
```
```diff
@@ -225,7 +222,6 @@ class LlamaGuardShield(ShieldBase):
                 is_violation=False,
             )
         else:
-
             prompt = self.build_prompt(messages)
             llama_guard_input = {
                 "role": "user",
```
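This hunk shows the run path: either short-circuit with a non-violation response or build the prompt and wrap it as a user-role message for the Llama Guard model. A simplified sketch (ShieldResponse, the guard condition, and the prompt builder are stand-ins, not the repo's actual code):

```python
# Sketch of the flow in the @@ -225 hunk.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ShieldResponse:
    is_violation: bool


def run(messages: List[str]) -> ShieldResponse:
    if not messages:  # invented guard; the real condition differs
        return ShieldResponse(
            is_violation=False,
        )
    else:
        prompt = "\n\n".join(messages)  # stands in for build_prompt()
        llama_guard_input: Dict[str, str] = {
            "role": "user",
            "content": prompt,
        }
        # The real shield would send llama_guard_input to the model and
        # parse its verdict; here we just report no violation.
        _ = llama_guard_input
        return ShieldResponse(is_violation=False)


print(run(["User: hello"]))
```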
```diff
@@ -18,7 +18,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
 
 
 class PromptGuardShield(TextShield):
-
     class Mode(Enum):
         INJECTION = auto()
         JAILBREAK = auto()
```
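The nested Mode enum in the last hunk selects between Prompt Guard's two checks. A sketch of one way the mode might be used (the threshold and scoring are invented for illustration; only the enum shape comes from the diff):

```python
# Sketch of the nested Mode enum and a hypothetical violation check.
from enum import Enum, auto


class PromptGuardShield:
    class Mode(Enum):
        INJECTION = auto()
        JAILBREAK = auto()

    def __init__(self, mode: "PromptGuardShield.Mode", threshold: float = 0.9) -> None:
        self.mode = mode            # which classifier output to read
        self.threshold = threshold  # assumed score cutoff for a violation

    def is_violation(self, score: float) -> bool:
        return score >= self.threshold


shield = PromptGuardShield(PromptGuardShield.Mode.JAILBREAK)
print(shield.mode.name, shield.is_violation(0.95))  # JAILBREAK True
```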