mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
revert excluded cat defaults
This commit is contained in:
parent
ab8a220faa
commit
ab829b0557
1 changed files with 5 additions and 2 deletions
|
@ -111,7 +111,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
def instance(
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = [],
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
) -> "LlamaGuardShield":
|
||||
|
@ -130,7 +130,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = [],
|
||||
excluded_categories: List[str] = None,
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
):
|
||||
|
@ -140,6 +140,9 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
|
||||
assert len(excluded_categories) == 0 or all(
|
||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue