From e12524af85fdc5f5e2a0721797bddae5ff9199a1 Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Tue, 5 Aug 2025 15:33:46 +0100 Subject: [PATCH] feat: create unregister shield API endpoint in Llama Stack (#2853) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Extend the Shields Protocol and implement the capability to unregister previously registered shields and CLI for shields management. Closes #2581 ## Test Plan First of, test API for shields 1. Install and start Ollama: `ollama serve` 2. Pull Llama Guard Model in Ollama: `ollama pull llama-guard3:8b` 3. Configure env variables: ``` export ENABLE_OLLAMA=ollama export OLLAMA_URL=http://localhost:11434 ``` 4. Build Llama Stack distro: `llama stack build --template starter --image-type venv ` 5. Start Llama Stack server: `llama stack run starter --port 8321` 6. Check if Ollama model is available: `curl -X GET http://localhost:8321/v1/models | jq '.data[] | select(.provider_id=="ollama")'` 7. Register a new Shield using Ollama provider: ``` curl -X POST http://localhost:8321/v1/shields \ -H "Content-Type: application/json" \ -d '{ "shield_id": "test-shield", "provider_id": "llama-guard", "provider_shield_id": "ollama/llama-guard3:8b", "params": {} }' ``` `{"identifier":"test-shield","provider_resource_id":"ollama/llama-guard3:8b","provider_id":"llama-guard","type":"shield","owner":{"principal":"","attributes":{}},"params":{}}% ` 8. Check if shield was registered: `curl -X GET http://localhost:8321/v1/shields/test-shield` `{"identifier":"test-shield","provider_resource_id":"ollama/llama-guard3:8b","provider_id":"llama-guard","type":"shield","owner":{"principal":"","attributes":{}},"params":{}}% ` 9. Run shield: ``` curl -X POST http://localhost:8321/v1/safety/run-shield \ -H "Content-Type: application/json" \ -d '{ "shield_id": "test-shield", "messages": [ { "role": "user", "content": "How can I hack into someone computer?" } ], "params": {} }' ``` `{"violation":{"violation_level":"error","user_message":"I can't answer that. Can I help with something else?","metadata":{"violation_type":"S2"}}}% ` 10. Unregister shield: `curl -X DELETE http://localhost:8321/v1/shields/test-shield` `null% ` 11. Verify shield was deleted: `curl -X GET http://localhost:8321/v1/shields/test-shield` `{"detail":"Invalid value: Shield 'test-shield' not found"}%` All tests passed ✅ ``` ========================================================================== 430 passed, 194 warnings in 19.54s ========================================================================== /Users/iamiller/GitHub/llama-stack/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/async_client_cleanup.py:78: RuntimeWarning: coroutine 'close_litellm_async_clients' was never awaited loop.close() RuntimeWarning: Enable tracemalloc to get the object allocation traceback Wrote HTML report to htmlcov-3.12/index.html ``` --- docs/_static/llama-stack-spec.html | 34 +++++++++++++++++ docs/_static/llama-stack-spec.yaml | 25 +++++++++++++ llama_stack/apis/shields/shields.py | 8 ++++ llama_stack/core/routers/safety.py | 4 ++ llama_stack/core/routing_tables/common.py | 2 + llama_stack/core/routing_tables/shields.py | 4 ++ llama_stack/providers/datatypes.py | 2 + .../inline/safety/llama_guard/llama_guard.py | 5 +++ .../safety/prompt_guard/prompt_guard.py | 3 ++ .../remote/safety/bedrock/bedrock.py | 3 ++ .../providers/remote/safety/nvidia/nvidia.py | 3 ++ .../remote/safety/sambanova/sambanova.py | 3 ++ .../routers/test_routing_tables.py | 37 ++++++++++++++++++- 13 files changed, 132 insertions(+), 1 deletion(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index f9af10165..79b9ede30 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1452,6 +1452,40 @@ } } ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Shields" + ], + "description": "Unregister a shield.", + "parameters": [ + { + "name": "identifier", + "in": "path", + "description": "The identifier of the shield to unregister.", + "required": true, + "schema": { + "type": "string" + } + } + ] } }, "/v1/telemetry/traces/{trace_id}/spans/{span_id}": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index d2c41b2bf..a15a2824e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -999,6 +999,31 @@ paths: required: true schema: type: string + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + description: Unregister a shield. + parameters: + - name: identifier + in: path + description: >- + The identifier of the shield to unregister. + required: true + schema: + type: string /v1/telemetry/traces/{trace_id}/spans/{span_id}: get: responses: diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 5d3e55c55..ec1b85349 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -83,3 +83,11 @@ class Shields(Protocol): :returns: A Shield. """ ... + + @webmethod(route="/shields/{identifier:path}", method="DELETE") + async def unregister_shield(self, identifier: str) -> None: + """Unregister a shield. + + :param identifier: The identifier of the shield to unregister. + """ + ... diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index 26ee8e722..f4273c7b5 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -43,6 +43,10 @@ class SafetyRouter(Safety): logger.debug(f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + async def unregister_shield(self, identifier: str) -> None: + logger.debug(f"SafetyRouter.unregister_shield: {identifier}") + return await self.routing_table.unregister_shield(identifier) + async def run_shield( self, shield_id: str, diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 4be3de42d..339ff6da4 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -60,6 +60,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_vector_db(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) + elif api == Api.safety: + return await p.unregister_shield(obj.identifier) elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index 0c592601a..e08f35bfc 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -55,3 +55,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) await self.register_object(shield) return shield + + async def unregister_shield(self, identifier: str) -> None: + existing_shield = await self.get_shield(identifier) + await self.unregister_object(existing_shield) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index f9f463bf9..5e15dd8e1 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -65,6 +65,8 @@ class ModelsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... + async def unregister_shield(self, identifier: str) -> None: ... + class VectorDBsProtocolPrivate(Protocol): async def register_vector_db(self, vector_db: VectorDB) -> None: ... diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index c580adfad..4a7e99e00 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if not model_id: raise ValueError("Llama Guard shield must have a model id") + async def unregister_shield(self, identifier: str) -> None: + # LlamaGuard doesn't need to do anything special for unregistration + # The routing table handles the removal from the registry + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index ee645a41d..796771ee1 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if shield.provider_resource_id != PROMPT_GUARD_MODEL: raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index c43b51073..1895e7507 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock" ) + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 411badb1c..7f17b1cb6 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): if not shield.provider_resource_id: raise ValueError("Shield model not provided.") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 3e0d03956..6c7190afe 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide ): logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 155ad0142..2652f5c8d 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -8,6 +8,8 @@ from unittest.mock import AsyncMock +import pytest + from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api @@ -78,6 +80,9 @@ class SafetyImpl(Impl): async def register_shield(self, shield: Shield): return shield + async def unregister_shield(self, shield_id: str): + return shield_id + class DatasetsImpl(Impl): def __init__(self): @@ -191,12 +196,42 @@ async def test_shields_routing_table(cached_disk_dist_registry): await table.register_shield(shield_id="test-shield", provider_id="test_provider") await table.register_shield(shield_id="test-shield-2", provider_id="test_provider") shields = await table.list_shields() - assert len(shields.data) == 2 + shield_ids = {s.identifier for s in shields.data} assert "test-shield" in shield_ids assert "test-shield-2" in shield_ids + # Test get specific shield + test_shield = await table.get_shield(identifier="test-shield") + assert test_shield is not None + assert test_shield.identifier == "test-shield" + assert test_shield.provider_id == "test_provider" + assert test_shield.provider_resource_id == "test-shield" + assert test_shield.params == {} + + # Test get non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.get_shield(identifier="non-existent") + + # Test unregistering shields + await table.unregister_shield(identifier="test-shield") + shields = await table.list_shields() + + assert len(shields.data) == 1 + shield_ids = {s.identifier for s in shields.data} + assert "test-shield" not in shield_ids + assert "test-shield-2" in shield_ids + + # Unregister the remaining shield + await table.unregister_shield(identifier="test-shield-2") + shields = await table.list_shields() + assert len(shields.data) == 0 + + # Test unregistering non-existent shield - should raise ValueError with specific message + with pytest.raises(ValueError, match="Shield 'non-existent' not found"): + await table.unregister_shield(identifier="non-existent") + async def test_vectordbs_routing_table(cached_disk_dist_registry): table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})