diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 0b74fd259..d60600b4a 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -48,5 +48,5 @@ class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield( - self, identifier: str, messages: List[Message], params: Dict[str, Any] = None + self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 760dbaf2f..173b28483 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -154,12 +154,12 @@ class SafetyRouter(Safety): async def run_shield( self, - identifier: str, + shield_type: str, messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: return await self.routing_table.get_provider_impl(identifier).run_shield( - identifier=identifier, + shield_type=shield_type, messages=messages, params=params, )