From ab829b055736e87173b6748000cfa0525469e2e6 Mon Sep 17 00:00:00 2001 From: Kate Plawiak Date: Mon, 22 Jul 2024 22:09:44 -0700 Subject: [PATCH] revert excluded cat defaults --- llama_toolchain/safety/shields/llama_guard.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py index d6154f027..5234c8e1f 100644 --- a/llama_toolchain/safety/shields/llama_guard.py +++ b/llama_toolchain/safety/shields/llama_guard.py @@ -111,7 +111,7 @@ class LlamaGuardShield(ShieldBase): def instance( on_violation_action=OnViolationAction.RAISE, model_dir: str = None, - excluded_categories: List[str] = [], + excluded_categories: List[str] = None, disable_input_check: bool = False, disable_output_check: bool = False, ) -> "LlamaGuardShield": @@ -130,7 +130,7 @@ class LlamaGuardShield(ShieldBase): self, on_violation_action: OnViolationAction = OnViolationAction.RAISE, model_dir: str = None, - excluded_categories: List[str] = [], + excluded_categories: List[str] = None, disable_input_check: bool = False, disable_output_check: bool = False, ): @@ -140,6 +140,9 @@ class LlamaGuardShield(ShieldBase): assert model_dir is not None, "Llama Guard model_dir is None" + if excluded_categories is None: + excluded_categories = [] + assert len(excluded_categories) == 0 or all( x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"