diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 38e53a438..c88b93a8c 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -1452,6 +1452,40 @@
}
}
]
+ },
+ "delete": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ },
+ "400": {
+ "$ref": "#/components/responses/BadRequest400"
+ },
+ "429": {
+ "$ref": "#/components/responses/TooManyRequests429"
+ },
+ "500": {
+ "$ref": "#/components/responses/InternalServerError500"
+ },
+ "default": {
+ "$ref": "#/components/responses/DefaultError"
+ }
+ },
+ "tags": [
+ "Shields"
+ ],
+ "description": "Unregister a shield.",
+ "parameters": [
+ {
+ "name": "identifier",
+ "in": "path",
+ "description": "The identifier of the shield to unregister.",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
}
},
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0df60ddf4..d3c322b7c 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -999,6 +999,31 @@ paths:
required: true
schema:
type: string
+ delete:
+ responses:
+ '200':
+ description: OK
+ '400':
+ $ref: '#/components/responses/BadRequest400'
+ '429':
+ $ref: >-
+ #/components/responses/TooManyRequests429
+ '500':
+ $ref: >-
+ #/components/responses/InternalServerError500
+ default:
+ $ref: '#/components/responses/DefaultError'
+ tags:
+ - Shields
+ description: Unregister a shield.
+ parameters:
+ - name: identifier
+ in: path
+ description: >-
+ The identifier of the shield to unregister.
+ required: true
+ schema:
+ type: string
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
get:
responses:
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index ce1f73d8e..e636e3176 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -79,3 +79,11 @@ class Shields(Protocol):
:returns: A Shield.
"""
...
+
+ @webmethod(route="/shields/{identifier:path}", method="DELETE")
+ async def unregister_shield(self, identifier: str) -> None:
+ """Unregister a shield.
+
+ :param identifier: The identifier of the shield to unregister.
+ """
+ ...
diff --git a/llama_stack/cli/llama.py b/llama_stack/cli/llama.py
index 433b311e7..1b5acabc4 100644
--- a/llama_stack/cli/llama.py
+++ b/llama_stack/cli/llama.py
@@ -8,6 +8,7 @@ import argparse
from .download import Download
from .model import ModelParser
+from .shield import ShieldParser
from .stack import StackParser
from .stack.utils import print_subcommand_description
from .verify_download import VerifyDownload
@@ -31,6 +32,7 @@ class LlamaCLIParser:
# Add sub-commands
ModelParser.create(subparsers)
+ ShieldParser.create(subparsers)
StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)
diff --git a/llama_stack/cli/shield/__init__.py b/llama_stack/cli/shield/__init__.py
new file mode 100644
index 000000000..bbc323db0
--- /dev/null
+++ b/llama_stack/cli/shield/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .shield import ShieldParser
+
+__all__ = ["ShieldParser"]
diff --git a/llama_stack/cli/shield/describe.py b/llama_stack/cli/shield/describe.py
new file mode 100644
index 000000000..9d457a9a6
--- /dev/null
+++ b/llama_stack/cli/shield/describe.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import json
+
+import httpx
+
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.cli.table import print_table
+
+
+class ShieldDescribe(Subcommand):
+ """Show details about a shield"""
+
+ def __init__(self, subparsers: argparse._SubParsersAction):
+ super().__init__()
+ self.parser = subparsers.add_parser(
+ "describe",
+ prog="llama shield describe",
+ description="Show details about a shield",
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+ self._add_arguments()
+ self.parser.set_defaults(func=self._run_shield_describe_cmd)
+
+ def _add_arguments(self):
+ self.parser.add_argument(
+ "shield_id",
+ type=str,
+ help="The identifier of the shield to describe",
+ )
+ self.parser.add_argument(
+ "--url",
+ type=str,
+ default="http://localhost:8321",
+ help="URL of the Llama Stack server (default: http://localhost:8321)",
+ )
+ self.parser.add_argument(
+ "--output-format",
+ choices=["table", "json"],
+ default="table",
+ help="Output format (default: table)",
+ )
+
+ def _run_shield_describe_cmd(self, args: argparse.Namespace) -> None:
+ try:
+ response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}")
+ response.raise_for_status()
+
+ shield = response.json()
+
+ if args.output_format == "json":
+ print(json.dumps(shield, indent=2))
+ return
+
+ headers = ["Property", "Value"]
+
+ shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
+ provider_id = shield.get("provider_id", "")
+ provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", ""))
+ resource_type = shield.get("type", "shield")
+
+ rows = [
+ ("Shield ID", shield_id),
+ ("Provider ID", provider_id),
+ ("Provider Shield ID", provider_shield_id),
+ ("Resource Type", resource_type),
+ ]
+
+ if shield.get("params"):
+ rows.append(("Parameters", json.dumps(shield["params"], indent=2)))
+ else:
+ rows.append(("Parameters", ""))
+
+ print_table(
+ rows,
+ headers,
+ separate_rows=True,
+ )
+
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 400 or e.response.status_code == 404:
+ print(f"Shield '{args.shield_id}' not found.")
+ else:
+ print(f"HTTP error {e.response.status_code}: {e.response.text}")
+ except httpx.RequestError as e:
+ print(f"Error connecting to Llama Stack server: {e}")
+ except Exception as e:
+ print(f"Error describing shield: {e}")
diff --git a/llama_stack/cli/shield/list.py b/llama_stack/cli/shield/list.py
new file mode 100644
index 000000000..f0696706e
--- /dev/null
+++ b/llama_stack/cli/shield/list.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import json
+
+import httpx
+
+from llama_stack.cli.subcommand import Subcommand
+from llama_stack.cli.table import print_table
+
+
+class ShieldList(Subcommand):
+ """List available shields"""
+
+ def __init__(self, subparsers: argparse._SubParsersAction):
+ super().__init__()
+ self.parser = subparsers.add_parser(
+ "list",
+ prog="llama shield list",
+ description="Show available shields",
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+ self._add_arguments()
+ self.parser.set_defaults(func=self._run_shield_list_cmd)
+
+ def _add_arguments(self):
+ self.parser.add_argument(
+ "--url",
+ type=str,
+ default="http://localhost:8321",
+ help="URL of the Llama Stack server (default: http://localhost:8321)",
+ )
+ self.parser.add_argument(
+ "--output-format",
+ choices=["table", "json"],
+ default="table",
+ help="Output format (default: table)",
+ )
+
+ def _run_shield_list_cmd(self, args: argparse.Namespace) -> None:
+ try:
+ response = httpx.get(f"{args.url}/v1/shields")
+ response.raise_for_status()
+
+ data = response.json()
+ shields = data.get("data", [])
+
+ if args.output_format == "json":
+ print(json.dumps(shields, indent=2))
+ return
+
+ if not shields:
+ print("No shields found.")
+ return
+
+ headers = ["Shield ID", "Provider ID", "Provider Shield ID", "Parameters"]
+
+ rows = []
+ for shield in shields:
+ params_str = ""
+ if shield.get("params"):
+ params_str = json.dumps(shield["params"], separators=(",", ":"))
+ if len(params_str) > 50:
+ params_str = params_str[:47] + "..."
+
+ rows.append(
+ [
+ shield.get("identifier", shield.get("shield_id", "-")),
+ shield.get("provider_id", "-"),
+ shield.get("provider_resource_id", shield.get("provider_shield_id", "-")),
+ params_str or "-",
+ ]
+ )
+
+ print_table(
+ rows,
+ headers,
+ separate_rows=True,
+ )
+
+ except httpx.RequestError as e:
+ print(f"Error connecting to Llama Stack server: {e}")
+ except httpx.HTTPStatusError as e:
+ print(f"HTTP error {e.response.status_code}: {e.response.text}")
+ except Exception as e:
+ print(f"Error listing shields: {e}")
diff --git a/llama_stack/cli/shield/register.py b/llama_stack/cli/shield/register.py
new file mode 100644
index 000000000..4e6d8b126
--- /dev/null
+++ b/llama_stack/cli/shield/register.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import json
+import sys
+
+import httpx
+
+from llama_stack.cli.subcommand import Subcommand
+
+
+class ShieldRegister(Subcommand):
+ """Register a new shield"""
+
+ def __init__(self, subparsers: argparse._SubParsersAction):
+ super().__init__()
+ self.parser = subparsers.add_parser(
+ "register",
+ prog="llama shield register",
+ description="Register a new shield",
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+ self._add_arguments()
+ self.parser.set_defaults(func=self._run_shield_register_cmd)
+
+ def _add_arguments(self):
+ self.parser.add_argument(
+ "shield_id",
+ type=str,
+ help="The identifier for the shield to register",
+ )
+ self.parser.add_argument(
+ "--provider-id",
+ type=str,
+ help="The provider ID for the shield",
+ )
+ self.parser.add_argument(
+ "--provider-shield-id",
+ type=str,
+ help="The provider-specific shield identifier",
+ )
+ self.parser.add_argument(
+ "--params",
+ type=str,
+ help='Shield parameters as JSON string (e.g., \'{"key": "value"}\')',
+ )
+ self.parser.add_argument(
+ "--url",
+ type=str,
+ default="http://localhost:8321",
+ help="URL of the Llama Stack server (default: http://localhost:8321)",
+ )
+
+ def _run_shield_register_cmd(self, args: argparse.Namespace) -> None:
+ try:
+ params = None
+ if args.params:
+ try:
+ params = json.loads(args.params)
+ except json.JSONDecodeError as e:
+ print(f"Error parsing parameters JSON: {e}")
+ sys.exit(1)
+
+ payload = {
+ "shield_id": args.shield_id,
+ }
+
+ if args.provider_id:
+ payload["provider_id"] = args.provider_id
+ if args.provider_shield_id:
+ payload["provider_shield_id"] = args.provider_shield_id
+ if params:
+ payload["params"] = params
+
+ response = httpx.post(f"{args.url}/v1/shields", json=payload, headers={"Content-Type": "application/json"})
+ response.raise_for_status()
+
+ shield = response.json()
+
+ print(f" Shield '{shield.get('identifier', args.shield_id)}' registered successfully!")
+ print(f" Provider ID: {shield.get('provider_id', '')}")
+ print(
+ f" Provider Shield ID: {shield.get('provider_resource_id', shield.get('provider_shield_id', ''))}"
+ )
+ if shield.get("params"):
+ print(f" Parameters: {json.dumps(shield['params'], indent=2)}")
+
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 400 and "already exists" in e.response.text:
+ print(f"Shield '{args.shield_id}' already exists.")
+ else:
+ print(f"HTTP error {e.response.status_code}: {e.response.text}")
+ sys.exit(1)
+ except httpx.RequestError as e:
+ print(f"Error connecting to Llama Stack server: {e}")
+ sys.exit(1)
+ except Exception as e:
+ print(f"Error registering shield: {e}")
+ sys.exit(1)
diff --git a/llama_stack/cli/shield/shield.py b/llama_stack/cli/shield/shield.py
new file mode 100644
index 000000000..299ea7aab
--- /dev/null
+++ b/llama_stack/cli/shield/shield.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+
+from llama_stack.cli.stack.utils import print_subcommand_description
+from llama_stack.cli.subcommand import Subcommand
+
+from .describe import ShieldDescribe
+from .list import ShieldList
+from .register import ShieldRegister
+from .unregister import ShieldUnregister
+
+
+class ShieldParser(Subcommand):
+ """Parser for shield commands"""
+
+ def __init__(self, subparsers: argparse._SubParsersAction):
+ super().__init__()
+ self.parser = subparsers.add_parser(
+ "shield",
+ prog="llama shield",
+ description="Manage safety shields",
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+
+ self.parser.set_defaults(func=lambda args: self.parser.print_help())
+
+ subparsers = self.parser.add_subparsers(title="shield_subcommands")
+
+ # Add shield sub-commands
+ ShieldList.create(subparsers)
+ ShieldRegister.create(subparsers)
+ ShieldDescribe.create(subparsers)
+ ShieldUnregister.create(subparsers)
+
+ print_subcommand_description(self.parser, subparsers)
diff --git a/llama_stack/cli/shield/unregister.py b/llama_stack/cli/shield/unregister.py
new file mode 100644
index 000000000..09dbc5d8b
--- /dev/null
+++ b/llama_stack/cli/shield/unregister.py
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import sys
+
+import httpx
+
+from llama_stack.cli.subcommand import Subcommand
+
+
+class ShieldUnregister(Subcommand):
+ """Unregister a shield"""
+
+ def __init__(self, subparsers: argparse._SubParsersAction):
+ super().__init__()
+ self.parser = subparsers.add_parser(
+ "unregister",
+ prog="llama shield unregister",
+ description="Unregister a shield",
+ formatter_class=argparse.RawTextHelpFormatter,
+ )
+ self._add_arguments()
+ self.parser.set_defaults(func=self._run_shield_unregister_cmd)
+
+ def _add_arguments(self):
+ self.parser.add_argument(
+ "shield_id",
+ type=str,
+ help="The identifier of the shield to unregister",
+ )
+ self.parser.add_argument(
+ "--url",
+ type=str,
+ default="http://localhost:8321",
+ help="URL of the Llama Stack server (default: http://localhost:8321)",
+ )
+ self.parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Force unregister without confirmation",
+ )
+
+ def _run_shield_unregister_cmd(self, args: argparse.Namespace) -> None:
+ try:
+ try:
+ response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}")
+ response.raise_for_status()
+ shield = response.json()
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 400 or e.response.status_code == 404:
+ print(f"Shield '{args.shield_id}' not found.")
+ sys.exit(1)
+ else:
+ raise
+
+ if not args.force:
+ shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
+ provider_id = shield.get("provider_id", "")
+ provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", ""))
+
+ print("Shield to unregister:")
+ print(f" - Shield ID: {shield_id}")
+ print(f" - Provider ID: {provider_id}")
+ print(f" - Provider Shield ID: {provider_shield_id}")
+
+ response_input = input(f"\nAre you sure you want to unregister shield '{args.shield_id}'? (y/N): ")
+ if response_input.lower() not in ["y", "yes"]:
+ print("Unregister cancelled.")
+ return
+
+ response = httpx.delete(f"{args.url}/v1/shields/{args.shield_id}")
+ response.raise_for_status()
+
+ print(f"Shield '{args.shield_id}' unregistered successfully!")
+
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 400 or e.response.status_code == 404:
+ print(f"Shield '{args.shield_id}' not found.")
+ else:
+ print(f"HTTP error {e.response.status_code}: {e.response.text}")
+ sys.exit(1)
+ except httpx.RequestError as e:
+ print(f"Error connecting to Llama Stack server: {e}")
+ sys.exit(1)
+ except Exception as e:
+ print(f"Error unregistering shield: {e}")
+ sys.exit(1)
diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py
index 26ee8e722..f4273c7b5 100644
--- a/llama_stack/distribution/routers/safety.py
+++ b/llama_stack/distribution/routers/safety.py
@@ -43,6 +43,10 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+ async def unregister_shield(self, identifier: str) -> None:
+ logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
+ return await self.routing_table.unregister_shield(identifier)
+
async def run_shield(
self,
shield_id: str,
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index 2f6ac90bb..4bf1b9756 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
+ elif api == Api.safety:
+ return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py
index 5215981b9..bd2b64453 100644
--- a/llama_stack/distribution/routing_tables/shields.py
+++ b/llama_stack/distribution/routing_tables/shields.py
@@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
await self.register_object(shield)
return shield
+
+ async def unregister_shield(self, identifier: str) -> None:
+ existing_shield = await self.get_shield(identifier)
+ if existing_shield is None:
+ raise ValueError(f"Shield '{identifier}' not found")
+ logger.info(f"Shield {identifier} was unregistered successfully.")
+ await self.unregister_object(existing_shield)
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 424380324..1978a2b54 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -51,6 +51,8 @@ class ModelsProtocolPrivate(Protocol):
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
+ async def unregister_shield(self, identifier: str) -> None: ...
+
class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 9d359e053..dc0474e5d 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
# The model will be validated during runtime when making inference calls
pass
+ async def unregister_shield(self, identifier: str) -> None:
+ # LlamaGuard doesn't need to do anything special for unregistration
+ # The routing table handles the removal from the registry
+ pass
+
async def run_shield(
self,
shield_id: str,
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index ff87889ea..d7a30d212 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self,
shield_id: str,
diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py
index c43b51073..1895e7507 100644
--- a/llama_stack/providers/remote/safety/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py
@@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
)
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
) -> RunShieldResponse:
diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py
index 411badb1c..7f17b1cb6 100644
--- a/llama_stack/providers/remote/safety/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py
@@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:
diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py
index 1a65f6aa1..e917b8c28 100644
--- a/llama_stack/providers/remote/safety/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py
@@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
):
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 12b05ebff..8968640a0 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -8,6 +8,8 @@
from unittest.mock import AsyncMock
+import pytest
+
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@@ -53,6 +55,9 @@ class SafetyImpl(Impl):
async def register_shield(self, shield: Shield):
return shield
+ async def unregister_shield(self, shield_id: str):
+ return shield_id
+
class DatasetsImpl(Impl):
def __init__(self):
@@ -166,12 +171,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
shields = await table.list_shields()
-
assert len(shields.data) == 2
+
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" in shield_ids
assert "test-shield-2" in shield_ids
+ # Test get specific shield
+ test_shield = await table.get_shield(identifier="test-shield")
+ assert test_shield is not None
+ assert test_shield.identifier == "test-shield"
+ assert test_shield.provider_id == "test_provider"
+ assert test_shield.provider_resource_id == "test-shield"
+ assert test_shield.params == {}
+
+ # Test get non-existent shield - should raise ValueError with specific message
+ with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
+ await table.get_shield(identifier="non-existent")
+
+ # Test unregistering shields
+ await table.unregister_shield(identifier="test-shield")
+ shields = await table.list_shields()
+
+ assert len(shields.data) == 1
+ shield_ids = {s.identifier for s in shields.data}
+ assert "test-shield" not in shield_ids
+ assert "test-shield-2" in shield_ids
+
+ # Unregister the remaining shield
+ await table.unregister_shield(identifier="test-shield-2")
+ shields = await table.list_shields()
+ assert len(shields.data) == 0
+
+ # Test unregistering non-existent shield - should raise ValueError with specific message
+ with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
+ await table.unregister_shield(identifier="non-existent")
+
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})