add safety/list_shields to query available shields

This commit is contained in:
Xi Yan 2024-09-22 01:08:32 -07:00
parent 44fe099866
commit b8914bb56f
3 changed files with 52 additions and 3 deletions

View file

@ -80,6 +80,27 @@ class MetaReferenceSafetyImpl(Safety):
return RunShieldResponse(violation=violation)
async def list_shields(self) -> ListShieldsResponse:
    """Report which shield types this implementation can currently serve.

    Returns:
        A ``ListShieldsResponse`` whose ``shields`` list holds the ``value``
        of every ``MetaReferenceShieldType`` member accepted by
        ``is_supported``, in the iteration order of that type.
    """
    available = []
    for shield_type in MetaReferenceShieldType:
        # Only advertise shields that is_supported() accepts.
        if self.is_supported(shield_type):
            available.append(shield_type.value)
    return ListShieldsResponse(shields=available)
def is_supported(self, typ: MetaReferenceShieldType) -> bool:
    """Tell whether the given shield type is usable with the current config.

    Args:
        typ: The shield type to check.

    Returns:
        ``True`` for ``llama_guard`` when ``config.llama_guard_shield`` is
        set, for ``jailbreak_shield`` / ``injection_shield`` when
        ``config.prompt_guard_shield`` is set, and unconditionally for
        ``code_scanner_guard``; ``False`` for anything else.
    """
    kinds = MetaReferenceShieldType
    if typ == kinds.llama_guard:
        return self.config.llama_guard_shield is not None
    # Jailbreak and injection detection both depend on the same config entry.
    if typ in (kinds.jailbreak_shield, kinds.injection_shield):
        return self.config.prompt_guard_shield is not None
    # The code scanner has no config prerequisite; any other type is unsupported.
    return typ == kinds.code_scanner_guard
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard: