diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 38e53a438..c88b93a8c 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -1452,6 +1452,40 @@
}
}
]
+ },
+ "delete": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ },
+ "400": {
+ "$ref": "#/components/responses/BadRequest400"
+ },
+ "429": {
+ "$ref": "#/components/responses/TooManyRequests429"
+ },
+ "500": {
+ "$ref": "#/components/responses/InternalServerError500"
+ },
+ "default": {
+ "$ref": "#/components/responses/DefaultError"
+ }
+ },
+ "tags": [
+ "Shields"
+ ],
+ "description": "Unregister a shield.",
+ "parameters": [
+ {
+ "name": "identifier",
+ "in": "path",
+ "description": "The identifier of the shield to unregister.",
+ "required": true,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
}
},
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 0df60ddf4..d3c322b7c 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -999,6 +999,31 @@ paths:
required: true
schema:
type: string
+ delete:
+ responses:
+ '200':
+ description: OK
+ '400':
+ $ref: '#/components/responses/BadRequest400'
+ '429':
+ $ref: >-
+ #/components/responses/TooManyRequests429
+ '500':
+ $ref: >-
+ #/components/responses/InternalServerError500
+ default:
+ $ref: '#/components/responses/DefaultError'
+ tags:
+ - Shields
+ description: Unregister a shield.
+ parameters:
+ - name: identifier
+ in: path
+ description: >-
+ The identifier of the shield to unregister.
+ required: true
+ schema:
+ type: string
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
get:
responses:
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index ce1f73d8e..e636e3176 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -79,3 +79,11 @@ class Shields(Protocol):
:returns: A Shield.
"""
...
+
+ @webmethod(route="/shields/{identifier:path}", method="DELETE")
+ async def unregister_shield(self, identifier: str) -> None:
+ """Unregister a shield.
+
+ :param identifier: The identifier of the shield to unregister.
+ """
+ ...
diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py
index 26ee8e722..f4273c7b5 100644
--- a/llama_stack/distribution/routers/safety.py
+++ b/llama_stack/distribution/routers/safety.py
@@ -43,6 +43,10 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+ async def unregister_shield(self, identifier: str) -> None:
+ logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
+ return await self.routing_table.unregister_shield(identifier)
+
async def run_shield(
self,
shield_id: str,
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index caf0780fd..2e51db093 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
+ elif api == Api.safety:
+ return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py
index 5215981b9..bd2b64453 100644
--- a/llama_stack/distribution/routing_tables/shields.py
+++ b/llama_stack/distribution/routing_tables/shields.py
@@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
await self.register_object(shield)
return shield
+
+ async def unregister_shield(self, identifier: str) -> None:
+ existing_shield = await self.get_shield(identifier)
+ if existing_shield is None:
+ raise ValueError(f"Shield '{identifier}' not found")
+ logger.info(f"Shield {identifier} was unregistered successfully.")
+ await self.unregister_object(existing_shield)
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 005bfbab8..d500e6261 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -62,6 +62,8 @@ class ModelsProtocolPrivate(Protocol):
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
+ async def unregister_shield(self, identifier: str) -> None: ...
+
class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 9d359e053..dc0474e5d 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
# The model will be validated during runtime when making inference calls
pass
+ async def unregister_shield(self, identifier: str) -> None:
+ # LlamaGuard doesn't need to do anything special for unregistration
+ # The routing table handles the removal from the registry
+ pass
+
async def run_shield(
self,
shield_id: str,
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index ff87889ea..d7a30d212 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self,
shield_id: str,
diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py
index c43b51073..1895e7507 100644
--- a/llama_stack/providers/remote/safety/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py
@@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
)
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
) -> RunShieldResponse:
diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py
index 411badb1c..7f17b1cb6 100644
--- a/llama_stack/providers/remote/safety/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py
@@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:
diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py
index 1a65f6aa1..e917b8c28 100644
--- a/llama_stack/providers/remote/safety/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py
@@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
):
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
+ async def unregister_shield(self, identifier: str) -> None:
+ pass
+
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index c1b57cb4f..1ce663d8a 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -8,6 +8,8 @@
from unittest.mock import AsyncMock
+import pytest
+
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@@ -78,6 +80,9 @@ class SafetyImpl(Impl):
async def register_shield(self, shield: Shield):
return shield
+ async def unregister_shield(self, shield_id: str):
+ return shield_id
+
class DatasetsImpl(Impl):
def __init__(self):
@@ -191,12 +196,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
shields = await table.list_shields()
-
assert len(shields.data) == 2
+
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" in shield_ids
assert "test-shield-2" in shield_ids
+ # Test get specific shield
+ test_shield = await table.get_shield(identifier="test-shield")
+ assert test_shield is not None
+ assert test_shield.identifier == "test-shield"
+ assert test_shield.provider_id == "test_provider"
+ assert test_shield.provider_resource_id == "test-shield"
+ assert test_shield.params == {}
+
+ # Test get non-existent shield - should raise ValueError with specific message
+ with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
+ await table.get_shield(identifier="non-existent")
+
+ # Test unregistering shields
+ await table.unregister_shield(identifier="test-shield")
+ shields = await table.list_shields()
+
+ assert len(shields.data) == 1
+ shield_ids = {s.identifier for s in shields.data}
+ assert "test-shield" not in shield_ids
+ assert "test-shield-2" in shield_ids
+
+ # Unregister the remaining shield
+ await table.unregister_shield(identifier="test-shield-2")
+ shields = await table.list_shields()
+ assert len(shields.data) == 0
+
+ # Test unregistering non-existent shield - should raise ValueError with specific message
+ with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
+ await table.unregister_shield(identifier="non-existent")
+
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})