From 67bfd946dc17d9d6d5f06ec2bf17edca61abb2a6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 6 Nov 2024 16:11:36 -0800 Subject: [PATCH] fix safety sig --- llama_stack/apis/safety/safety.py | 2 +- llama_stack/distribution/routers/routers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 0b74fd259..d60600b4a 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -48,5 +48,5 @@ class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield( - self, identifier: str, messages: List[Message], params: Dict[str, Any] = None + self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 760dbaf2f..173b28483 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -154,12 +154,12 @@ class SafetyRouter(Safety): async def run_shield( self, - identifier: str, + shield_type: str, messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: return await self.routing_table.get_provider_impl(identifier).run_shield( - identifier=identifier, + shield_type=shield_type, messages=messages, params=params, )