mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
safety api
This commit is contained in:
parent
959c499cac
commit
4b34f741d0
3 changed files with 9 additions and 11 deletions
|
@ -86,5 +86,6 @@ class Safety(Protocol):
|
|||
@webmethod(route="/safety/run_shields")
|
||||
async def run_shields(
|
||||
self,
|
||||
request: RunShieldRequest,
|
||||
messages: List[Message],
|
||||
shields: List[ShieldDefinition],
|
||||
) -> RunShieldResponse: ...
|
||||
|
|
|
@ -13,10 +13,10 @@ import fire
|
|||
import httpx
|
||||
|
||||
from llama_models.llama3.api.datatypes import UserMessage
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from .api import * # noqa: F403
|
||||
|
||||
|
@ -43,9 +43,7 @@ class SafetyClient(Safety):
|
|||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/safety/run_shields",
|
||||
json={
|
||||
"request": encodable_dict(request),
|
||||
},
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
|
|
|
@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
|
||||
async def run_shields(
|
||||
self,
|
||||
request: RunShieldRequest,
|
||||
messages: List[Message],
|
||||
shields: List[ShieldDefinition],
|
||||
) -> RunShieldResponse:
|
||||
shields = [shield_config_to_shield(c, self.config) for c in request.shields]
|
||||
shields = [shield_config_to_shield(c, self.config) for c in shields]
|
||||
|
||||
responses = await asyncio.gather(
|
||||
*[shield.run(request.messages) for shield in shields]
|
||||
)
|
||||
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
|
||||
|
||||
return RunShieldResponse(responses=responses)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue