diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 4d1783b94..6792c9579 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -221,6 +221,44 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): return await impl.run(messages) + async def create( + self, + input: str | list[str], + model: str | None = None, # To replace with default model for llama-guard + ) -> ModerationObject: + if model is None: + raise ValueError("Model cannot be None") + if not input or len(input) == 0: + raise ValueError("Input cannot be empty") + if isinstance(input, list): + messages = input.copy() + else: + messages = [input] + + # convert to user messages format with role + messages = [UserMessage(content=m) for m in messages] + + # Use the inference API's model resolution instead of hardcoded mappings + # This allows the shield to work with any registered model + # Determine safety categories based on the model type + # For known Llama Guard models, use specific categories + if model in LLAMA_GUARD_MODEL_IDS: + # Use the mapped model for categories but the original model_id for inference + mapped_model = LLAMA_GUARD_MODEL_IDS[model] + safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES) + else: + # For unknown models, use default Llama Guard 3 8B categories + safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE] + + impl = LlamaGuardShield( + model=model, + inference_api=self.inference_api, + excluded_categories=self.config.excluded_categories, + safety_categories=safety_categories, + ) + + return await impl.run_create(messages) + async def create( self,