diff --git a/llama_toolchain/safety/api/api.py b/llama_toolchain/safety/api/api.py index 96682d172..631cfa992 100644 --- a/llama_toolchain/safety/api/api.py +++ b/llama_toolchain/safety/api/api.py @@ -86,5 +86,6 @@ class Safety(Protocol): @webmethod(route="/safety/run_shields") async def run_shields( self, - request: RunShieldRequest, + messages: List[Message], + shields: List[ShieldDefinition], ) -> RunShieldResponse: ... diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py index 0cf7deae8..26a9813b3 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_toolchain/safety/client.py @@ -13,10 +13,10 @@ import fire import httpx from llama_models.llama3.api.datatypes import UserMessage -from pydantic import BaseModel -from termcolor import cprint from llama_toolchain.core.datatypes import RemoteProviderConfig +from pydantic import BaseModel +from termcolor import cprint from .api import * # noqa: F403 @@ -43,9 +43,7 @@ class SafetyClient(Safety): async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/safety/run_shields", - json={ - "request": encodable_dict(request), - }, + json=encodable_dict(request), headers={"Content-Type": "application/json"}, timeout=20, ) diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py index e71ac09a2..6c75e74e8 100644 --- a/llama_toolchain/safety/meta_reference/safety.py +++ b/llama_toolchain/safety/meta_reference/safety.py @@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety): async def run_shields( self, - request: RunShieldRequest, + messages: List[Message], + shields: List[ShieldDefinition], ) -> RunShieldResponse: - shields = [shield_config_to_shield(c, self.config) for c in request.shields] + shields = [shield_config_to_shield(c, self.config) for c in shields] - responses = await asyncio.gather( - *[shield.run(request.messages) for shield in shields] - ) + responses = await asyncio.gather(*[shield.run(messages) for shield in shields]) return RunShieldResponse(responses=responses)