updating license for toolchain

This commit is contained in:
Ashwin Bharambe 2024-07-22 20:31:42 -07:00
parent 0e2fc9966a
commit 86fff23a9e
74 changed files with 512 additions and 94 deletions

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described found in the
# LICENSE file in the root directory of this source tree.
import re
from string import Template
@ -100,7 +106,7 @@ class LlamaGuardShield(ShieldBase):
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = [],
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
@ -119,7 +125,7 @@ class LlamaGuardShield(ShieldBase):
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = [],
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
):
@ -129,6 +135,8 @@ class LlamaGuardShield(ShieldBase):
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"