Merge branch 'main' into fix_llama_guard_inference

This commit is contained in:
Kate Plawiak 2024-07-22 21:31:18 -07:00
commit 7a8b5c1604
75 changed files with 979 additions and 182 deletions

View file

@ -1,3 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described found in the
# LICENSE file in the root directory of this source tree.
import re
from string import Template
@ -100,7 +112,7 @@ class LlamaGuardShield(ShieldBase):
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = [],
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
@ -119,7 +131,7 @@ class LlamaGuardShield(ShieldBase):
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = [],
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
):
@ -129,6 +141,8 @@ class LlamaGuardShield(ShieldBase):
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"