diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py index 6e74efd8f..29bb94420 100644 --- a/llama_stack/apis/safety/client.py +++ b/llama_stack/apis/safety/client.py @@ -13,10 +13,11 @@ import fire import httpx from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.distribution.datatypes import RemoteProviderConfig from pydantic import BaseModel from termcolor import cprint +from llama_stack.distribution.datatypes import RemoteProviderConfig + from llama_stack.apis.safety import * # noqa: F403 @@ -61,24 +62,6 @@ class SafetyClient(Safety): content = response.json() return RunShieldResponse(**content) - async def list_shields(self) -> ListShieldsResponse: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.base_url}/safety/list_shields", - json={}, - headers={"Content-Type": "application/json"}, - timeout=20, - ) - - if response.status_code != 200: - content = await response.aread() - error = f"Error: HTTP {response.status_code} {content.decode()}" - cprint(error, "red") - raise Exception(error) - - content = response.json() - return ListShieldsResponse(**content) - async def run_main(host: str, port: int): client = SafetyClient(f"http://{host}:{port}") @@ -100,9 +83,6 @@ async def run_main(host: str, port: int): ) print(response) - response = await client.list_shields() - print(response) - def main(host: str, port: int): asyncio.run(run_main(host, port)) diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 468bf48de..f855bd1a7 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -37,16 +37,8 @@ class RunShieldResponse(BaseModel): violation: Optional[SafetyViolation] = None -@json_schema_type -class ListShieldsResponse(BaseModel): - shields: List[str] = None - - class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield( self, shield: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... - - @webmethod(route="/safety/list_shields") - async def list_shields(self) -> ListShieldsResponse: ... diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 28b3eed78..6eccf47a5 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -80,27 +80,6 @@ class MetaReferenceSafetyImpl(Safety): return RunShieldResponse(violation=violation) - async def list_shields(self) -> ListShieldsResponse: - supported_sheilds = [ - v.value for v in MetaReferenceShieldType if self.is_supported(v) - ] - return ListShieldsResponse(shields=supported_sheilds) - - def is_supported(self, typ: MetaReferenceShieldType) -> bool: - if typ == MetaReferenceShieldType.llama_guard: - return self.config.llama_guard_shield is not None - - if typ == MetaReferenceShieldType.jailbreak_shield: - return self.config.prompt_guard_shield is not None - - if typ == MetaReferenceShieldType.injection_shield: - return self.config.prompt_guard_shield is not None - - if typ == MetaReferenceShieldType.code_scanner_guard: - return True - - return False - def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: cfg = self.config if typ == MetaReferenceShieldType.llama_guard: