diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index d60600b4a..0b74fd259 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -48,5 +48,5 @@ class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 173b28483..760dbaf2f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -154,12 +154,12 @@ class SafetyRouter(Safety): async def run_shield( self, - shield_type: str, + identifier: str, messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: return await self.routing_table.get_provider_impl(identifier).run_shield( - shield_type=shield_type, + identifier=identifier, messages=messages, params=params, ) diff --git a/llama_stack/providers/inline/meta_reference/safety/safety.py b/llama_stack/providers/inline/meta_reference/safety/safety.py index 28c78b65c..2d0db7624 100644 --- a/llama_stack/providers/inline/meta_reference/safety/safety.py +++ b/llama_stack/providers/inline/meta_reference/safety/safety.py @@ -57,13 +57,13 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): async def run_shield( self, - shield_type: str, + identifier: str, messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(shield_type) + shield_def = await self.shield_store.get_shield(identifier) if not shield_def: - raise ValueError(f"Unknown shield {shield_type}") + raise ValueError(f"Unknown shield {identifier}") shield = self.get_shield_impl(shield_def)