more fixes (some fixes to pre-existing issues in safety fixture)

This commit is contained in:
Ashwin Bharambe 2024-11-11 09:09:47 -08:00
parent 7507cd487f
commit 15ffceb533
6 changed files with 25 additions and 9 deletions

View file

@ -128,7 +128,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type != ShieldType.llama_guard.value:
print(f"Registering shield {shield}")
if shield.shield_type != ShieldType.llama_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")
async def run_shield(