diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py index d6154f027..5234c8e1f 100644 --- a/llama_toolchain/safety/shields/llama_guard.py +++ b/llama_toolchain/safety/shields/llama_guard.py @@ -111,7 +111,7 @@ class LlamaGuardShield(ShieldBase): def instance( on_violation_action=OnViolationAction.RAISE, model_dir: str = None, - excluded_categories: List[str] = [], + excluded_categories: List[str] = None, disable_input_check: bool = False, disable_output_check: bool = False, ) -> "LlamaGuardShield": @@ -130,7 +130,7 @@ class LlamaGuardShield(ShieldBase): self, on_violation_action: OnViolationAction = OnViolationAction.RAISE, model_dir: str = None, - excluded_categories: List[str] = [], + excluded_categories: List[str] = None, disable_input_check: bool = False, disable_output_check: bool = False, ): @@ -140,6 +140,9 @@ class LlamaGuardShield(ShieldBase): assert model_dir is not None, "Llama Guard model_dir is None" + if excluded_categories is None: + excluded_categories = [] + assert len(excluded_categories) == 0 or all( x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"