mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
safety api
This commit is contained in:
parent
959c499cac
commit
4b34f741d0
3 changed files with 9 additions and 11 deletions
|
@ -86,5 +86,6 @@ class Safety(Protocol):
|
||||||
@webmethod(route="/safety/run_shields")
|
@webmethod(route="/safety/run_shields")
|
||||||
async def run_shields(
|
async def run_shields(
|
||||||
self,
|
self,
|
||||||
request: RunShieldRequest,
|
messages: List[Message],
|
||||||
|
shields: List[ShieldDefinition],
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
|
@ -13,10 +13,10 @@ import fire
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import UserMessage
|
from llama_models.llama3.api.datatypes import UserMessage
|
||||||
from pydantic import BaseModel
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from .api import * # noqa: F403
|
from .api import * # noqa: F403
|
||||||
|
|
||||||
|
@ -43,9 +43,7 @@ class SafetyClient(Safety):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/safety/run_shields",
|
f"{self.base_url}/safety/run_shields",
|
||||||
json={
|
json=encodable_dict(request),
|
||||||
"request": encodable_dict(request),
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
|
|
|
@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
|
|
||||||
async def run_shields(
|
async def run_shields(
|
||||||
self,
|
self,
|
||||||
request: RunShieldRequest,
|
messages: List[Message],
|
||||||
|
shields: List[ShieldDefinition],
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shields = [shield_config_to_shield(c, self.config) for c in request.shields]
|
shields = [shield_config_to_shield(c, self.config) for c in shields]
|
||||||
|
|
||||||
responses = await asyncio.gather(
|
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
|
||||||
*[shield.run(request.messages) for shield in shields]
|
|
||||||
)
|
|
||||||
|
|
||||||
return RunShieldResponse(responses=responses)
|
return RunShieldResponse(responses=responses)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue