diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 38e53a438..c88b93a8c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1452,6 +1452,40 @@ } } ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Shields" + ], + "description": "Unregister a shield.", + "parameters": [ + { + "name": "identifier", + "in": "path", + "description": "The identifier of the shield to unregister.", + "required": true, + "schema": { + "type": "string" + } + } + ] } }, "/v1/telemetry/traces/{trace_id}/spans/{span_id}": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 0df60ddf4..d3c322b7c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -999,6 +999,31 @@ paths: required: true schema: type: string + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + description: Unregister a shield. + parameters: + - name: identifier + in: path + description: >- + The identifier of the shield to unregister. 
+ required: true + schema: + type: string /v1/telemetry/traces/{trace_id}/spans/{span_id}: get: responses: diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ce1f73d8e..e636e3176 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -79,3 +79,11 @@ class Shields(Protocol): :returns: A Shield. """ ... + + @webmethod(route="/shields/{identifier:path}", method="DELETE") + async def unregister_shield(self, identifier: str) -> None: + """Unregister a shield. + + :param identifier: The identifier of the shield to unregister. + """ + ... diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py index 26ee8e722..f4273c7b5 100644 --- a/llama_stack/distribution/routers/safety.py +++ b/llama_stack/distribution/routers/safety.py @@ -43,6 +43,10 @@ class SafetyRouter(Safety): logger.debug(f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + async def unregister_shield(self, identifier: str) -> None: + logger.debug(f"SafetyRouter.unregister_shield: {identifier}") + return await self.routing_table.unregister_shield(identifier) + async def run_shield( self, shield_id: str, diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index caf0780fd..2e51db093 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_vector_db(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) + elif api == Api.safety: + return await p.unregister_shield(obj.identifier) elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: diff --git 
a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py index 5215981b9..bd2b64453 100644 --- a/llama_stack/distribution/routing_tables/shields.py +++ b/llama_stack/distribution/routing_tables/shields.py @@ -55,3 +55,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) await self.register_object(shield) return shield + + async def unregister_shield(self, identifier: str) -> None: + existing_shield = await self.get_shield(identifier) + if existing_shield is None: + raise ValueError(f"Shield '{identifier}' not found") + await self.unregister_object(existing_shield) + # Log success only after the removal has completed without raising. + logger.info(f"Shield {identifier} was unregistered successfully.") diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 005bfbab8..d500e6261 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -62,6 +62,8 @@ class ModelsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... + async def unregister_shield(self, identifier: str) -> None: ... + class VectorDBsProtocolPrivate(Protocol): async def register_vector_db(self, vector_db: VectorDB) -> None: ... 
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 9d359e053..dc0474e5d 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): # The model will be validated during runtime when making inference calls pass + async def unregister_shield(self, identifier: str) -> None: + # LlamaGuard doesn't need to do anything special for unregistration + # The routing table handles the removal from the registry + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index ff87889ea..d7a30d212 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if shield.provider_resource_id != PROMPT_GUARD_MODEL: raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. 
") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index c43b51073..1895e7507 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock" ) + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 411badb1c..7f17b1cb6 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): if not shield.provider_resource_id: raise ValueError("Shield model not provided.") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 1a65f6aa1..e917b8c28 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide ): logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( 
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index c1b57cb4f..1ce663d8a 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -8,6 +8,8 @@ from unittest.mock import AsyncMock +import pytest + from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -78,6 +80,9 @@ class SafetyImpl(Impl): async def register_shield(self, shield: Shield): return shield + async def unregister_shield(self, shield_id: str): + return shield_id + class DatasetsImpl(Impl): def __init__(self): @@ -191,12 +196,42 @@ async def test_shields_routing_table(cached_disk_dist_registry): await table.register_shield(shield_id="test-shield", provider_id="test_provider") await table.register_shield(shield_id="test-shield-2", provider_id="test_provider") shields = await table.list_shields() - assert len(shields.data) == 2 + shield_ids = {s.identifier for s in shields.data} assert "test-shield" in shield_ids assert "test-shield-2" in shield_ids + # Test get specific shield + test_shield = await table.get_shield(identifier="test-shield") + assert test_shield is not None + assert test_shield.identifier == "test-shield" + assert test_shield.provider_id == "test_provider" + assert test_shield.provider_resource_id == "test-shield" + assert test_shield.params == {} + + # Test get non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.get_shield(identifier="non-existent") + + # Test unregistering shields + await table.unregister_shield(identifier="test-shield") + shields = await table.list_shields() + + assert 
len(shields.data) == 1 + shield_ids = {s.identifier for s in shields.data} + assert "test-shield" not in shield_ids + assert "test-shield-2" in shield_ids + + # Unregister the remaining shield + await table.unregister_shield(identifier="test-shield-2") + shields = await table.list_shields() + assert len(shields.data) == 0 + + # Test unregistering non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.unregister_shield(identifier="non-existent") + async def test_vectordbs_routing_table(cached_disk_dist_registry): table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})