mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
add safety/list_shields to query available shields
This commit is contained in:
parent
44fe099866
commit
b8914bb56f
3 changed files with 52 additions and 3 deletions
|
@ -13,11 +13,10 @@ import fire
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
|
||||||
|
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,6 +61,24 @@ class SafetyClient(Safety):
|
||||||
content = response.json()
|
content = response.json()
|
||||||
return RunShieldResponse(**content)
|
return RunShieldResponse(**content)
|
||||||
|
|
||||||
|
async def list_shields(self) -> ListShieldsResponse:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.base_url}/safety/list_shields",
|
||||||
|
json={},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
content = await response.aread()
|
||||||
|
error = f"Error: HTTP {response.status_code} {content.decode()}"
|
||||||
|
cprint(error, "red")
|
||||||
|
raise Exception(error)
|
||||||
|
|
||||||
|
content = response.json()
|
||||||
|
return ListShieldsResponse(**content)
|
||||||
|
|
||||||
|
|
||||||
async def run_main(host: str, port: int):
|
async def run_main(host: str, port: int):
|
||||||
client = SafetyClient(f"http://{host}:{port}")
|
client = SafetyClient(f"http://{host}:{port}")
|
||||||
|
@ -83,6 +100,9 @@ async def run_main(host: str, port: int):
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
response = await client.list_shields()
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int):
|
def main(host: str, port: int):
|
||||||
asyncio.run(run_main(host, port))
|
asyncio.run(run_main(host, port))
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Protocol
|
from typing import Any, Dict, List, Protocol
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -37,8 +37,16 @@ class RunShieldResponse(BaseModel):
|
||||||
violation: Optional[SafetyViolation] = None
|
violation: Optional[SafetyViolation] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ListShieldsResponse(BaseModel):
|
||||||
|
shields: List[str] = None
|
||||||
|
|
||||||
|
|
||||||
class Safety(Protocol):
|
class Safety(Protocol):
|
||||||
@webmethod(route="/safety/run_shield")
|
@webmethod(route="/safety/run_shield")
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield: str, messages: List[Message], params: Dict[str, Any] = None
|
self, shield: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
||||||
|
@webmethod(route="/safety/list_shields")
|
||||||
|
async def list_shields(self) -> ListShieldsResponse: ...
|
||||||
|
|
|
@ -80,6 +80,27 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
|
|
||||||
return RunShieldResponse(violation=violation)
|
return RunShieldResponse(violation=violation)
|
||||||
|
|
||||||
|
async def list_shields(self) -> ListShieldsResponse:
|
||||||
|
supported_sheilds = [
|
||||||
|
v.value for v in MetaReferenceShieldType if self.is_supported(v)
|
||||||
|
]
|
||||||
|
return ListShieldsResponse(shields=supported_sheilds)
|
||||||
|
|
||||||
|
def is_supported(self, typ: MetaReferenceShieldType) -> bool:
|
||||||
|
if typ == MetaReferenceShieldType.llama_guard:
|
||||||
|
return self.config.llama_guard_shield is not None
|
||||||
|
|
||||||
|
if typ == MetaReferenceShieldType.jailbreak_shield:
|
||||||
|
return self.config.prompt_guard_shield is not None
|
||||||
|
|
||||||
|
if typ == MetaReferenceShieldType.injection_shield:
|
||||||
|
return self.config.prompt_guard_shield is not None
|
||||||
|
|
||||||
|
if typ == MetaReferenceShieldType.code_scanner_guard:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||||
cfg = self.config
|
cfg = self.config
|
||||||
if typ == MetaReferenceShieldType.llama_guard:
|
if typ == MetaReferenceShieldType.llama_guard:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue