From ac6034ed1e249d61ba2a842174b5ce6ec95d206f Mon Sep 17 00:00:00 2001 From: r3v5 Date: Tue, 22 Jul 2025 12:28:18 +0100 Subject: [PATCH] feat: create unregister shield API endpoint --- docs/_static/llama-stack-spec.html | 34 +++++++++++++++++ docs/_static/llama-stack-spec.yaml | 25 +++++++++++++ llama_stack/apis/shields/shields.py | 8 ++++ llama_stack/distribution/routers/safety.py | 4 ++ .../distribution/routing_tables/common.py | 2 + .../distribution/routing_tables/shields.py | 7 ++++ llama_stack/providers/datatypes.py | 2 + .../inline/safety/llama_guard/llama_guard.py | 5 +++ .../safety/prompt_guard/prompt_guard.py | 3 ++ .../remote/safety/bedrock/bedrock.py | 3 ++ .../providers/remote/safety/nvidia/nvidia.py | 3 ++ .../remote/safety/sambanova/sambanova.py | 3 ++ .../routers/test_routing_tables.py | 37 ++++++++++++++++++- 13 files changed, 135 insertions(+), 1 deletion(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index d7801ba1c..f460e50e6 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1452,6 +1452,40 @@ } } ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Shields" + ], + "description": "Unregister a shield.", + "parameters": [ + { + "name": "identifier", + "in": "path", + "description": "The identifier of the shield to unregister.", + "required": true, + "schema": { + "type": "string" + } + } + ] } }, "/v1/telemetry/traces/{trace_id}/spans/{span_id}": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index be02e1e42..621ed6d8b 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -999,6 +999,31 @@ paths: required: true schema: type: string + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + description: Unregister a shield. + parameters: + - name: identifier + in: path + description: >- + The identifier of the shield to unregister. + required: true + schema: + type: string /v1/telemetry/traces/{trace_id}/spans/{span_id}: get: responses: diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ce1f73d8e..e636e3176 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -79,3 +79,11 @@ class Shields(Protocol): :returns: A Shield. """ ... + + @webmethod(route="/shields/{identifier:path}", method="DELETE") + async def unregister_shield(self, identifier: str) -> None: + """Unregister a shield. + + :param identifier: The identifier of the shield to unregister. + """ + ... diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py index 9761d2db0..e6f3ee4dc 100644 --- a/llama_stack/distribution/routers/safety.py +++ b/llama_stack/distribution/routers/safety.py @@ -43,6 +43,10 @@ class SafetyRouter(Safety): logger.debug(f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + async def unregister_shield(self, identifier: str) -> None: + logger.debug(f"SafetyRouter.unregister_shield: {identifier}") + return await self.routing_table.unregister_shield(identifier) + async def run_shield( self, shield_id: str, diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 7f7de32fe..1386e7e86 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_vector_db(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) + elif api == Api.safety: + return await p.unregister_shield(obj.identifier) elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/distribution/routing_tables/shields.py index 5215981b9..bd2b64453 100644 --- a/llama_stack/distribution/routing_tables/shields.py +++ b/llama_stack/distribution/routing_tables/shields.py @@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) await self.register_object(shield) return shield + + async def unregister_shield(self, identifier: str) -> None: + existing_shield = await self.get_shield(identifier) + if existing_shield is None: + raise ValueError(f"Shield '{identifier}' not found") + logger.info(f"Shield {identifier} was unregistered successfully.") + await self.unregister_object(existing_shield) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index efe8a98fe..e53d001c4 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -51,6 +51,8 @@ class ModelsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... + async def unregister_shield(self, identifier: str) -> None: ... + class VectorDBsProtocolPrivate(Protocol): async def register_vector_db(self, vector_db: VectorDB) -> None: ... diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 9d359e053..dc0474e5d 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): # The model will be validated during runtime when making inference calls pass + async def unregister_shield(self, identifier: str) -> None: + # LlamaGuard doesn't need to do anything special for unregistration + # The routing table handles the removal from the registry + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index ff87889ea..d7a30d212 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if shield.provider_resource_id != PROMPT_GUARD_MODEL: raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index c43b51073..1895e7507 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock" ) + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 411badb1c..7f17b1cb6 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): if not shield.provider_resource_id: raise ValueError("Shield model not provided.") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 1a65f6aa1..e917b8c28 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide ): logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 3ba042bd9..3587738ab 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -8,6 +8,8 @@ from unittest.mock import AsyncMock +import pytest + from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -53,6 +55,9 @@ class SafetyImpl(Impl): async def register_shield(self, shield: Shield): return shield + async def unregister_shield(self, shield_id: str): + return shield_id + class VectorDBImpl(Impl): def __init__(self): @@ -166,12 +171,42 @@ async def test_shields_routing_table(cached_disk_dist_registry): await table.register_shield(shield_id="test-shield", provider_id="test_provider") await table.register_shield(shield_id="test-shield-2", provider_id="test_provider") shields = await table.list_shields() - assert len(shields.data) == 2 + shield_ids = {s.identifier for s in shields.data} assert "test-shield" in shield_ids assert "test-shield-2" in shield_ids + # Test get specific shield + test_shield = await table.get_shield(identifier="test-shield") + assert test_shield is not None + assert test_shield.identifier == "test-shield" + assert test_shield.provider_id == "test_provider" + assert test_shield.provider_resource_id == "test-shield" + assert test_shield.params == {} + + # Test get non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.get_shield(identifier="non-existent") + + # Test unregistering shields + await table.unregister_shield(identifier="test-shield") + shields = await table.list_shields() + + assert len(shields.data) == 1 + shield_ids = {s.identifier for s in shields.data} + assert "test-shield" not in shield_ids + assert "test-shield-2" in shield_ids + + # Unregister the remaining shield + await table.unregister_shield(identifier="test-shield-2") + shields = await table.list_shields() + assert len(shields.data) == 0 + + # Test unregistering non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.unregister_shield(identifier="non-existent") + async def test_vectordbs_routing_table(cached_disk_dist_registry): table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})