formatting

This commit is contained in:
Dalton Flanagan 2024-08-14 14:22:25 -04:00
parent 94dfa293a6
commit b6ccaf1778
33 changed files with 110 additions and 97 deletions

View file

@ -41,7 +41,6 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
self.config = config

View file

@ -14,7 +14,6 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
@ -60,7 +59,6 @@ class TextShield(ShieldBase):
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"

View file

@ -12,7 +12,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard

View file

@ -100,7 +100,6 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
@ -166,7 +165,6 @@ class LlamaGuardShield(ShieldBase):
return None
def get_safety_categories(self) -> List[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
@ -181,7 +179,6 @@ class LlamaGuardShield(ShieldBase):
return categories
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
@ -225,7 +222,6 @@ class LlamaGuardShield(ShieldBase):
is_violation=False,
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",

View file

@ -18,7 +18,6 @@ from llama_toolchain.safety.api.datatypes import * # noqa: F403
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()