accept huggingface repo IDs as shield ids for llama guard

This commit is contained in:
Ashwin Bharambe 2024-11-18 13:42:13 -08:00
parent a562668dcd
commit 38563d7c00

View file

@ -73,18 +73,21 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS,
]
LLAMA_GUARD_MODEL_IDS = [
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
# accept both CoreModelId and huggingface repo id
LLAMA_GUARD_MODEL_IDS = {
CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
"meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
"meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
"meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
}
MODEL_TO_SAFETY_CATEGORIES_MAP = {
CoreModelId.llama_guard_3_8b.value: (
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
),
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
+ [CAT_CODE_INTERPRETER_ABUSE],
"meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
"meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
}
@ -150,8 +153,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
impl = LlamaGuardShield(
model=shield.provider_resource_id,
model=model,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
)