safety api

This commit is contained in:
Xi Yan 2024-09-11 13:41:15 -07:00
parent 959c499cac
commit 4b34f741d0
3 changed files with 9 additions and 11 deletions

View file

@ -86,5 +86,6 @@ class Safety(Protocol):
@webmethod(route="/safety/run_shields") @webmethod(route="/safety/run_shields")
async def run_shields( async def run_shields(
self, self,
request: RunShieldRequest, messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse: ... ) -> RunShieldResponse: ...

View file

@ -13,10 +13,10 @@ import fire
import httpx import httpx
from llama_models.llama3.api.datatypes import UserMessage from llama_models.llama3.api.datatypes import UserMessage
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_toolchain.core.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from .api import * # noqa: F403 from .api import * # noqa: F403
@ -43,9 +43,7 @@ class SafetyClient(Safety):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/safety/run_shields", f"{self.base_url}/safety/run_shields",
json={ json=encodable_dict(request),
"request": encodable_dict(request),
},
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) )

View file

@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
async def run_shields( async def run_shields(
self, self,
request: RunShieldRequest, messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse: ) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in request.shields] shields = [shield_config_to_shield(c, self.config) for c in shields]
responses = await asyncio.gather( responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
*[shield.run(request.messages) for shield in shields]
)
return RunShieldResponse(responses=responses) return RunShieldResponse(responses=responses)