Accept Hugging Face repo IDs as shield IDs for Llama Guard

This commit is contained in:
Ashwin Bharambe 2024-11-18 13:42:13 -08:00
parent a562668dcd
commit 38563d7c00

View file

@@ -73,18 +73,21 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS, CAT_ELECTIONS,
] ]
LLAMA_GUARD_MODEL_IDS = [ # accept both CoreModelId and huggingface repo id
CoreModelId.llama_guard_3_8b.value, LLAMA_GUARD_MODEL_IDS = {
CoreModelId.llama_guard_3_1b.value, CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_11b_vision.value, "meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
] CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
"meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
"meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
}
MODEL_TO_SAFETY_CATEGORIES_MAP = { MODEL_TO_SAFETY_CATEGORIES_MAP = {
CoreModelId.llama_guard_3_8b.value: ( "meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE] + [CAT_CODE_INTERPRETER_ABUSE],
), "meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES, "meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
} }
@@ -150,8 +153,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value: if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content) messages[0] = UserMessage(content=messages[0].content)
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
impl = LlamaGuardShield( impl = LlamaGuardShield(
model=shield.provider_resource_id, model=model,
inference_api=self.inference_api, inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories, excluded_categories=self.config.excluded_categories,
) )