more fixes (some fixes to pre-existing issues in safety fixture)

This commit is contained in:
Ashwin Bharambe 2024-11-11 09:09:47 -08:00
parent 7507cd487f
commit 15ffceb533
6 changed files with 25 additions and 9 deletions

View file

@ -25,7 +25,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.code_scanner.value:
if shield.shield_type != ShieldType.code_scanner:
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
async def run_shield(

View file

@ -128,7 +128,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.llama_guard.value:
print(f"Registering shield {shield}")
if shield.shield_type != ShieldType.llama_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def run_shield(

View file

@ -36,7 +36,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.prompt_guard.value:
if shield.shield_type != ShieldType.prompt_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def run_shield(