From 93be3aa9b6b4bab9eb5b835defb1dae4653592e2 Mon Sep 17 00:00:00 2001 From: r3v5 Date: Tue, 22 Jul 2025 12:28:18 +0100 Subject: [PATCH] feat: create unregister shield API endpoint and CLI for shields management --- docs/_static/llama-stack-spec.html | 34 ++++++ docs/_static/llama-stack-spec.yaml | 25 +++++ llama_stack/apis/shields/shields.py | 8 ++ llama_stack/cli/llama.py | 2 + llama_stack/cli/shield/__init__.py | 9 ++ llama_stack/cli/shield/describe.py | 93 ++++++++++++++++ llama_stack/cli/shield/list.py | 90 +++++++++++++++ llama_stack/cli/shield/register.py | 103 ++++++++++++++++++ llama_stack/cli/shield/shield.py | 40 +++++++ llama_stack/cli/shield/unregister.py | 91 ++++++++++++++++ llama_stack/distribution/routers/safety.py | 4 + .../distribution/routing_tables/common.py | 2 + .../distribution/routing_tables/shields.py | 7 ++ llama_stack/providers/datatypes.py | 2 + .../inline/safety/llama_guard/llama_guard.py | 5 + .../safety/prompt_guard/prompt_guard.py | 3 + .../remote/safety/bedrock/bedrock.py | 3 + .../providers/remote/safety/nvidia/nvidia.py | 3 + .../remote/safety/sambanova/sambanova.py | 3 + .../routers/test_routing_tables.py | 37 ++++++- 20 files changed, 563 insertions(+), 1 deletion(-) create mode 100644 llama_stack/cli/shield/__init__.py create mode 100644 llama_stack/cli/shield/describe.py create mode 100644 llama_stack/cli/shield/list.py create mode 100644 llama_stack/cli/shield/register.py create mode 100644 llama_stack/cli/shield/shield.py create mode 100644 llama_stack/cli/shield/unregister.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 38e53a438..c88b93a8c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1452,6 +1452,40 @@ } } ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Shields" + ], + "description": "Unregister a shield.", + "parameters": [ + { + "name": "identifier", + "in": "path", + "description": "The identifier of the shield to unregister.", + "required": true, + "schema": { + "type": "string" + } + } + ] } }, "/v1/telemetry/traces/{trace_id}/spans/{span_id}": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 0df60ddf4..d3c322b7c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -999,6 +999,31 @@ paths: required: true schema: type: string + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + description: Unregister a shield. + parameters: + - name: identifier + in: path + description: >- + The identifier of the shield to unregister. + required: true + schema: + type: string /v1/telemetry/traces/{trace_id}/spans/{span_id}: get: responses: diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ce1f73d8e..e636e3176 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -79,3 +79,11 @@ class Shields(Protocol): :returns: A Shield. """ ... + + @webmethod(route="/shields/{identifier:path}", method="DELETE") + async def unregister_shield(self, identifier: str) -> None: + """Unregister a shield. + + :param identifier: The identifier of the shield to unregister. + """ + ... diff --git a/llama_stack/cli/llama.py b/llama_stack/cli/llama.py index 433b311e7..1b5acabc4 100644 --- a/llama_stack/cli/llama.py +++ b/llama_stack/cli/llama.py @@ -8,6 +8,7 @@ import argparse from .download import Download from .model import ModelParser +from .shield import ShieldParser from .stack import StackParser from .stack.utils import print_subcommand_description from .verify_download import VerifyDownload @@ -31,6 +32,7 @@ class LlamaCLIParser: # Add sub-commands ModelParser.create(subparsers) + ShieldParser.create(subparsers) StackParser.create(subparsers) Download.create(subparsers) VerifyDownload.create(subparsers) diff --git a/llama_stack/cli/shield/__init__.py b/llama_stack/cli/shield/__init__.py new file mode 100644 index 000000000..bbc323db0 --- /dev/null +++ b/llama_stack/cli/shield/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .shield import ShieldParser + +__all__ = ["ShieldParser"] diff --git a/llama_stack/cli/shield/describe.py b/llama_stack/cli/shield/describe.py new file mode 100644 index 000000000..9d457a9a6 --- /dev/null +++ b/llama_stack/cli/shield/describe.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import json + +import httpx + +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table + + +class ShieldDescribe(Subcommand): + """Show details about a shield""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "describe", + prog="llama shield describe", + description="Show details about a shield", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_shield_describe_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "shield_id", + type=str, + help="The identifier of the shield to describe", + ) + self.parser.add_argument( + "--url", + type=str, + default="http://localhost:8321", + help="URL of the Llama Stack server (default: http://localhost:8321)", + ) + self.parser.add_argument( + "--output-format", + choices=["table", "json"], + default="table", + help="Output format (default: table)", + ) + + def _run_shield_describe_cmd(self, args: argparse.Namespace) -> None: + try: + response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}") + response.raise_for_status() + + shield = response.json() + + if args.output_format == "json": + print(json.dumps(shield, indent=2)) + return + + headers = ["Property", "Value"] + + shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id)) + provider_id = shield.get("provider_id", "") + provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "")) + resource_type = shield.get("type", "shield") + + rows = [ + ("Shield ID", shield_id), + ("Provider ID", provider_id), + ("Provider Shield ID", provider_shield_id), + ("Resource Type", resource_type), + ] + + if shield.get("params"): + rows.append(("Parameters", json.dumps(shield["params"], indent=2))) + else: + rows.append(("Parameters", "")) + + print_table( + rows, + headers, + separate_rows=True, + ) + + except httpx.HTTPStatusError as e: + if e.response.status_code == 400 or e.response.status_code == 404: + print(f"Shield '{args.shield_id}' not found.") + else: + print(f"HTTP error {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + print(f"Error connecting to Llama Stack server: {e}") + except Exception as e: + print(f"Error describing shield: {e}") diff --git a/llama_stack/cli/shield/list.py b/llama_stack/cli/shield/list.py new file mode 100644 index 000000000..f0696706e --- /dev/null +++ b/llama_stack/cli/shield/list.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import json + +import httpx + +from llama_stack.cli.subcommand import Subcommand +from llama_stack.cli.table import print_table + + +class ShieldList(Subcommand): + """List available shields""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "list", + prog="llama shield list", + description="Show available shields", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_shield_list_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "--url", + type=str, + default="http://localhost:8321", + help="URL of the Llama Stack server (default: http://localhost:8321)", + ) + self.parser.add_argument( + "--output-format", + choices=["table", "json"], + default="table", + help="Output format (default: table)", + ) + + def _run_shield_list_cmd(self, args: argparse.Namespace) -> None: + try: + response = httpx.get(f"{args.url}/v1/shields") + response.raise_for_status() + + data = response.json() + shields = data.get("data", []) + + if args.output_format == "json": + print(json.dumps(shields, indent=2)) + return + + if not shields: + print("No shields found.") + return + + headers = ["Shield ID", "Provider ID", "Provider Shield ID", "Parameters"] + + rows = [] + for shield in shields: + params_str = "" + if shield.get("params"): + params_str = json.dumps(shield["params"], separators=(",", ":")) + if len(params_str) > 50: + params_str = params_str[:47] + "..." + + rows.append( + [ + shield.get("identifier", shield.get("shield_id", "-")), + shield.get("provider_id", "-"), + shield.get("provider_resource_id", shield.get("provider_shield_id", "-")), + params_str or "-", + ] + ) + + print_table( + rows, + headers, + separate_rows=True, + ) + + except httpx.RequestError as e: + print(f"Error connecting to Llama Stack server: {e}") + except httpx.HTTPStatusError as e: + print(f"HTTP error {e.response.status_code}: {e.response.text}") + except Exception as e: + print(f"Error listing shields: {e}") diff --git a/llama_stack/cli/shield/register.py b/llama_stack/cli/shield/register.py new file mode 100644 index 000000000..4e6d8b126 --- /dev/null +++ b/llama_stack/cli/shield/register.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import json +import sys + +import httpx + +from llama_stack.cli.subcommand import Subcommand + + +class ShieldRegister(Subcommand): + """Register a new shield""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "register", + prog="llama shield register", + description="Register a new shield", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_shield_register_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "shield_id", + type=str, + help="The identifier for the shield to register", + ) + self.parser.add_argument( + "--provider-id", + type=str, + help="The provider ID for the shield", + ) + self.parser.add_argument( + "--provider-shield-id", + type=str, + help="The provider-specific shield identifier", + ) + self.parser.add_argument( + "--params", + type=str, + help='Shield parameters as JSON string (e.g., \'{"key": "value"}\')', + ) + self.parser.add_argument( + "--url", + type=str, + default="http://localhost:8321", + help="URL of the Llama Stack server (default: http://localhost:8321)", + ) + + def _run_shield_register_cmd(self, args: argparse.Namespace) -> None: + try: + params = None + if args.params: + try: + params = json.loads(args.params) + except json.JSONDecodeError as e: + print(f"Error parsing parameters JSON: {e}") + sys.exit(1) + + payload = { + "shield_id": args.shield_id, + } + + if args.provider_id: + payload["provider_id"] = args.provider_id + if args.provider_shield_id: + payload["provider_shield_id"] = args.provider_shield_id + if params: + payload["params"] = params + + response = httpx.post(f"{args.url}/v1/shields", json=payload, headers={"Content-Type": "application/json"}) + response.raise_for_status() + + shield = response.json() + + print(f" Shield '{shield.get('identifier', args.shield_id)}' registered successfully!") + print(f" Provider ID: {shield.get('provider_id', '')}") + print( + f" Provider Shield ID: {shield.get('provider_resource_id', shield.get('provider_shield_id', ''))}" + ) + if shield.get("params"): + print(f" Parameters: {json.dumps(shield['params'], indent=2)}") + + except httpx.HTTPStatusError as e: + if e.response.status_code == 400 and "already exists" in e.response.text: + print(f"Shield '{args.shield_id}' already exists.") + else: + print(f"HTTP error {e.response.status_code}: {e.response.text}") + sys.exit(1) + except httpx.RequestError as e: + print(f"Error connecting to Llama Stack server: {e}") + sys.exit(1) + except Exception as e: + print(f"Error registering shield: {e}") + sys.exit(1) diff --git a/llama_stack/cli/shield/shield.py b/llama_stack/cli/shield/shield.py new file mode 100644 index 000000000..299ea7aab --- /dev/null +++ b/llama_stack/cli/shield/shield.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + +from llama_stack.cli.stack.utils import print_subcommand_description +from llama_stack.cli.subcommand import Subcommand + +from .describe import ShieldDescribe +from .list import ShieldList +from .register import ShieldRegister +from .unregister import ShieldUnregister + + +class ShieldParser(Subcommand): + """Parser for shield commands""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "shield", + prog="llama shield", + description="Manage safety shields", + formatter_class=argparse.RawTextHelpFormatter, + ) + + self.parser.set_defaults(func=lambda args: self.parser.print_help()) + + subparsers = self.parser.add_subparsers(title="shield_subcommands") + + # Add shield sub-commands + ShieldList.create(subparsers) + ShieldRegister.create(subparsers) + ShieldDescribe.create(subparsers) + ShieldUnregister.create(subparsers) + + print_subcommand_description(self.parser, subparsers) diff --git a/llama_stack/cli/shield/unregister.py b/llama_stack/cli/shield/unregister.py new file mode 100644 index 000000000..09dbc5d8b --- /dev/null +++ b/llama_stack/cli/shield/unregister.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import sys + +import httpx + +from llama_stack.cli.subcommand import Subcommand + + +class ShieldUnregister(Subcommand): + """Unregister a shield""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "unregister", + prog="llama shield unregister", + description="Unregister a shield", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_shield_unregister_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "shield_id", + type=str, + help="The identifier of the shield to unregister", + ) + self.parser.add_argument( + "--url", + type=str, + default="http://localhost:8321", + help="URL of the Llama Stack server (default: http://localhost:8321)", + ) + self.parser.add_argument( + "--force", + action="store_true", + help="Force unregister without confirmation", + ) + + def _run_shield_unregister_cmd(self, args: argparse.Namespace) -> None: + try: + try: + response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}") + response.raise_for_status() + shield = response.json() + except httpx.HTTPStatusError as e: + if e.response.status_code == 400 or e.response.status_code == 404: + print(f"Shield '{args.shield_id}' not found.") + sys.exit(1) + else: + raise + + if not args.force: + shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id)) + provider_id = shield.get("provider_id", "") + provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "")) + + print("Shield to unregister:") + print(f" - Shield ID: {shield_id}") + print(f" - Provider ID: {provider_id}") + print(f" - Provider Shield ID: {provider_shield_id}") + + response_input = input(f"\nAre you sure you want to unregister shield '{args.shield_id}'? (y/N): ") + if response_input.lower() not in ["y", "yes"]: + print("Unregister cancelled.") + return + + response = httpx.delete(f"{args.url}/v1/shields/{args.shield_id}") + response.raise_for_status() + + print(f"Shield '{args.shield_id}' unregistered successfully!") + + except httpx.HTTPStatusError as e: + if e.response.status_code == 400 or e.response.status_code == 404: + print(f"Shield '{args.shield_id}' not found.") + else: + print(f"HTTP error {e.response.status_code}: {e.response.text}") + sys.exit(1) + except httpx.RequestError as e: + print(f"Error connecting to Llama Stack server: {e}") + sys.exit(1) + except Exception as e: + print(f"Error unregistering shield: {e}") + sys.exit(1) diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py index 26ee8e722..f4273c7b5 100644 --- a/llama_stack/distribution/routers/safety.py +++ b/llama_stack/distribution/routers/safety.py @@ -43,6 +43,10 @@ class SafetyRouter(Safety): logger.debug(f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + async def unregister_shield(self, identifier: str) -> None: + logger.debug(f"SafetyRouter.unregister_shield: {identifier}") + return await self.routing_table.unregister_shield(identifier) + async def run_shield( self, shield_id: str, diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 2f6ac90bb..4bf1b9756 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_vector_db(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) + elif api == Api.safety: + return await p.unregister_shield(obj.identifier) elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py index 5215981b9..bd2b64453 100644 --- a/llama_stack/distribution/routing_tables/shields.py +++ b/llama_stack/distribution/routing_tables/shields.py @@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) await self.register_object(shield) return shield + + async def unregister_shield(self, identifier: str) -> None: + existing_shield = await self.get_shield(identifier) + if existing_shield is None: + raise ValueError(f"Shield '{identifier}' not found") + logger.info(f"Shield {identifier} was unregistered successfully.") + await self.unregister_object(existing_shield) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 424380324..1978a2b54 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -51,6 +51,8 @@ class ModelsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... + async def unregister_shield(self, identifier: str) -> None: ... + class VectorDBsProtocolPrivate(Protocol): async def register_vector_db(self, vector_db: VectorDB) -> None: ... diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 9d359e053..dc0474e5d 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): # The model will be validated during runtime when making inference calls pass + async def unregister_shield(self, identifier: str) -> None: + # LlamaGuard doesn't need to do anything special for unregistration + # The routing table handles the removal from the registry + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index ff87889ea..d7a30d212 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if shield.provider_resource_id != PROMPT_GUARD_MODEL: raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index c43b51073..1895e7507 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock" ) + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 411badb1c..7f17b1cb6 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): if not shield.provider_resource_id: raise ValueError("Shield model not provided.") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 1a65f6aa1..e917b8c28 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide ): logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 12b05ebff..8968640a0 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -8,6 +8,8 @@ from unittest.mock import AsyncMock +import pytest + from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -53,6 +55,9 @@ class SafetyImpl(Impl): async def register_shield(self, shield: Shield): return shield + async def unregister_shield(self, shield_id: str): + return shield_id + class DatasetsImpl(Impl): def __init__(self): @@ -166,12 +171,42 @@ async def test_shields_routing_table(cached_disk_dist_registry): await table.register_shield(shield_id="test-shield", provider_id="test_provider") await table.register_shield(shield_id="test-shield-2", provider_id="test_provider") shields = await table.list_shields() - assert len(shields.data) == 2 + shield_ids = {s.identifier for s in shields.data} assert "test-shield" in shield_ids assert "test-shield-2" in shield_ids + # Test get specific shield + test_shield = await table.get_shield(identifier="test-shield") + assert test_shield is not None + assert test_shield.identifier == "test-shield" + assert test_shield.provider_id == "test_provider" + assert test_shield.provider_resource_id == "test-shield" + assert test_shield.params == {} + + # Test get non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.get_shield(identifier="non-existent") + + # Test unregistering shields + await table.unregister_shield(identifier="test-shield") + shields = await table.list_shields() + + assert len(shields.data) == 1 + shield_ids = {s.identifier for s in shields.data} + assert "test-shield" not in shield_ids + assert "test-shield-2" in shield_ids + + # Unregister the remaining shield + await table.unregister_shield(identifier="test-shield-2") + shields = await table.list_shields() + assert len(shields.data) == 0 + + # Test unregistering non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.unregister_shield(identifier="non-existent") + async def test_vectordbs_routing_table(cached_disk_dist_registry): table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})