safety api

This commit is contained in:
Xi Yan 2024-09-11 13:41:15 -07:00
parent 959c499cac
commit 4b34f741d0
3 changed files with 9 additions and 11 deletions

View file

@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
async def run_shields(
self,
request: RunShieldRequest,
messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in request.shields]
shields = [shield_config_to_shield(c, self.config) for c in shields]
responses = await asyncio.gather(
*[shield.run(request.messages) for shield in shields]
)
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
return RunShieldResponse(responses=responses)