diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 9950064a4..f201d550f 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -73,18 +73,21 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
     CAT_ELECTIONS,
 ]
 
-LLAMA_GUARD_MODEL_IDS = [
-    CoreModelId.llama_guard_3_8b.value,
-    CoreModelId.llama_guard_3_1b.value,
-    CoreModelId.llama_guard_3_11b_vision.value,
-]
+# accept both CoreModelId and huggingface repo id
+LLAMA_GUARD_MODEL_IDS = {
+    CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
+    "meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
+    CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
+    "meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
+    CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
+    "meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
+}
 
 MODEL_TO_SAFETY_CATEGORIES_MAP = {
-    CoreModelId.llama_guard_3_8b.value: (
-        DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
-    ),
-    CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
-    CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
+    "meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES
+    + [CAT_CODE_INTERPRETER_ABUSE],
+    "meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
+    "meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
 }
 
 
@@ -150,8 +153,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         if len(messages) > 0 and messages[0].role != Role.user.value:
             messages[0] = UserMessage(content=messages[0].content)
 
+        model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
         impl = LlamaGuardShield(
-            model=shield.provider_resource_id,
+            model=model,
             inference_api=self.inference_api,
             excluded_categories=self.config.excluded_categories,
         )