safety api

This commit is contained in:
Xi Yan 2024-09-11 13:41:15 -07:00
parent 959c499cac
commit 4b34f741d0
3 changed files with 9 additions and 11 deletions

View file

@ -86,5 +86,6 @@ class Safety(Protocol):
@webmethod(route="/safety/run_shields")
async def run_shields(
self,
request: RunShieldRequest,
messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse: ...

View file

@ -13,10 +13,10 @@ import fire
import httpx
from llama_models.llama3.api.datatypes import UserMessage
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint
from .api import * # noqa: F403
@ -43,9 +43,7 @@ class SafetyClient(Safety):
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/safety/run_shields",
json={
"request": encodable_dict(request),
},
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)

View file

@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
async def run_shields(
self,
request: RunShieldRequest,
messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in request.shields]
shields = [shield_config_to_shield(c, self.config) for c in shields]
responses = await asyncio.gather(
*[shield.run(request.messages) for shield in shields]
)
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
return RunShieldResponse(responses=responses)