diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 17577b0c9..25f1805b9 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -127,27 +127,19 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_shield(self, shield: ShieldDef) -> None:
-        raise ValueError("Registering dynamic shields is not supported")
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return [
-            ShieldDef(
-                identifier=ShieldType.llama_guard.value,
-                shield_type=ShieldType.llama_guard.value,
-                params={},
-            ),
-        ]
+    async def register_shield(self, shield: Shield) -> None:
+        if shield.shield_type != ShieldType.llama_guard.value:
+            raise ValueError(f"Unsupported shield type: {shield.shield_type}")

     async def run_shield(
         self,
-        identifier: str,
+        shield_id: str,
         messages: List[Message],
         params: Dict[str, Any] = None,
    ) -> RunShieldResponse:
-        shield_def = await self.shield_store.get_shield(identifier)
-        if not shield_def:
-            raise ValueError(f"Unknown shield {identifier}")
+        shield = await self.shield_store.get_shield(shield_id)
+        if not shield:
+            raise ValueError(f"Unknown shield {shield_id}")

         messages = messages.copy()
         # some shields like llama-guard require the first message to be a user message
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 5cfafcde4..7754607ec 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -7,11 +7,11 @@
 from typing import Any, Dict, List

 import torch
-
-from llama_stack.distribution.utils.model_utils import model_local_dir
 from termcolor import cprint
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from llama_stack.distribution.utils.model_utils import model_local_dir

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -35,27 +35,19 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_shield(self, shield: ShieldDef) -> None:
-        raise ValueError("Registering dynamic shields is not supported")
-
-    async def list_shields(self) -> List[ShieldDef]:
-        return [
-            ShieldDef(
-                identifier=ShieldType.prompt_guard.value,
-                shield_type=ShieldType.prompt_guard.value,
-                params={},
-            )
-        ]
+    async def register_shield(self, shield: Shield) -> None:
+        if shield.shield_type != ShieldType.prompt_guard.value:
+            raise ValueError(f"Unsupported shield type: {shield.shield_type}")

     async def run_shield(
         self,
-        identifier: str,
+        shield_id: str,
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        shield_def = await self.shield_store.get_shield(identifier)
-        if not shield_def:
-            raise ValueError(f"Unknown shield {identifier}")
+        shield = await self.shield_store.get_shield(shield_id)
+        if not shield:
+            raise ValueError(f"Unknown shield {shield_id}")

         return await self.shield.run(messages)
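
For context, here is a minimal sketch of how a caller would exercise the changed surface: `register_shield` now accepts a `Shield` resource and validates its type rather than rejecting all registration, and `run_shield` is keyed by `shield_id` (formerly `identifier`). The names `safety`, `my_shield`, and `check_input` are illustrative and not part of the diff; this assumes a constructed provider instance and a `Shield` resource with an `identifier` field, per the shield-store lookup shown above.

```python
# Hypothetical usage sketch for the new API shape (names `safety`,
# `my_shield`, and `check_input` are illustrative, not from the diff).
from llama_stack.apis.safety import *  # noqa: F403
from llama_models.llama3.api.datatypes import *  # noqa: F403


async def check_input(safety, my_shield, user_text: str) -> bool:
    # register_shield now takes a Shield resource and only rejects
    # mismatched shield types, instead of refusing all registration
    await safety.register_shield(my_shield)

    # run_shield is keyed by shield_id; the provider resolves the
    # Shield through its shield_store before running the check
    response = await safety.run_shield(
        shield_id=my_shield.identifier,
        messages=[UserMessage(content=user_text)],
    )

    # RunShieldResponse carries a violation when content is flagged
    return response.violation is None
```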