From b8914bb56fee5a713e7ea33175c21583f9848a12 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sun, 22 Sep 2024 01:08:32 -0700
Subject: [PATCH] add safety/list_shields to query available shields

---
 llama_stack/apis/safety/client.py               | 24 +++++++++++++++++--
 llama_stack/apis/safety/safety.py               | 10 +++++++-
 .../impls/meta_reference/safety/safety.py       | 21 ++++++++++++++++
 3 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py
index 29bb94420..6e74efd8f 100644
--- a/llama_stack/apis/safety/client.py
+++ b/llama_stack/apis/safety/client.py
@@ -13,11 +13,10 @@ import fire
 import httpx
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import RemoteProviderConfig
 from pydantic import BaseModel
 from termcolor import cprint
 
-from llama_stack.distribution.datatypes import RemoteProviderConfig
-
 from llama_stack.apis.safety import *  # noqa: F403
 
 
@@ -62,6 +61,24 @@ class SafetyClient(Safety):
         content = response.json()
         return RunShieldResponse(**content)
 
+    async def list_shields(self) -> ListShieldsResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/safety/list_shields",
+                json={},
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+
+            if response.status_code != 200:
+                content = await response.aread()
+                error = f"Error: HTTP {response.status_code} {content.decode()}"
+                cprint(error, "red")
+                raise Exception(error)
+
+            content = response.json()
+            return ListShieldsResponse(**content)
+
 
 async def run_main(host: str, port: int):
     client = SafetyClient(f"http://{host}:{port}")
@@ -83,6 +100,9 @@ async def run_main(host: str, port: int):
     )
     print(response)
 
+    response = await client.list_shields()
+    print(response)
+
 
 def main(host: str, port: int):
     asyncio.run(run_main(host, port))
diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py
index cb8eb3c4a..468bf48de 100644
--- a/llama_stack/apis/safety/safety.py
+++ b/llama_stack/apis/safety/safety.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Dict, Protocol
+from typing import Any, Dict, List, Protocol
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel
@@ -37,8 +37,16 @@ class RunShieldResponse(BaseModel):
     violation: Optional[SafetyViolation] = None
 
 
+@json_schema_type
+class ListShieldsResponse(BaseModel):
+    shields: List[str]
+
+
 class Safety(Protocol):
     @webmethod(route="/safety/run_shield")
     async def run_shield(
         self, shield: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse: ...
+
+    @webmethod(route="/safety/list_shields")
+    async def list_shields(self) -> ListShieldsResponse: ...
diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index 6eccf47a5..28b3eed78 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -80,6 +80,27 @@ class MetaReferenceSafetyImpl(Safety):
 
         return RunShieldResponse(violation=violation)
 
+    async def list_shields(self) -> ListShieldsResponse:
+        supported_shields = [
+            v.value for v in MetaReferenceShieldType if self.is_supported(v)
+        ]
+        return ListShieldsResponse(shields=supported_shields)
+
+    def is_supported(self, typ: MetaReferenceShieldType) -> bool:
+        if typ == MetaReferenceShieldType.llama_guard:
+            return self.config.llama_guard_shield is not None
+
+        if typ == MetaReferenceShieldType.jailbreak_shield:
+            return self.config.prompt_guard_shield is not None
+
+        if typ == MetaReferenceShieldType.injection_shield:
+            return self.config.prompt_guard_shield is not None
+
+        if typ == MetaReferenceShieldType.code_scanner_guard:
+            return True
+
+        return False
+
     def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
         cfg = self.config
         if typ == MetaReferenceShieldType.llama_guard:
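
A quick way to exercise the new route end to end (a minimal sketch, not part of the patch; it assumes a llama-stack server is already running, and the host/port below are placeholders):

```python
# Minimal sketch: query the new /safety/list_shields route directly with httpx.
# Assumes a running llama-stack server; localhost:5000 is a placeholder.
import asyncio

import httpx


async def check_list_shields(host: str = "localhost", port: int = 5000) -> None:
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"http://{host}:{port}/safety/list_shields",
            json={},
            headers={"Content-Type": "application/json"},
            timeout=20,
        )
        response.raise_for_status()
        # Body follows ListShieldsResponse, e.g. {"shields": ["llama_guard", ...]}
        print(response.json()["shields"])


if __name__ == "__main__":
    asyncio.run(check_list_shields())
```

Against the meta-reference provider, the returned list reflects `is_supported`: `llama_guard` appears only when `llama_guard_shield` is configured, `jailbreak_shield` and `injection_shield` only when `prompt_guard_shield` is configured, and `code_scanner_guard` is always reported.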