mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
feat: create unregister shield API endpoint in Llama Stack (#2853)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Integration Tests (Replay) / discover-tests (push) Successful in 13s
Python Package Build Test / build (3.12) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 24s
Test External API and Providers / test-external (venv) (push) Failing after 12s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Update ReadTheDocs / update-readthedocs (push) Failing after 9s
Python Package Build Test / build (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 27s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 29s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 22s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 21s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 35s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 39s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 35s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 35s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 1m2s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 1m4s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 1m2s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 7s
Pre-commit / pre-commit (push) Successful in 2m21s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Integration Tests (Replay) / discover-tests (push) Successful in 13s
Python Package Build Test / build (3.12) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 24s
Test External API and Providers / test-external (venv) (push) Failing after 12s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Update ReadTheDocs / update-readthedocs (push) Failing after 9s
Python Package Build Test / build (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 27s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 29s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 22s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 21s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 35s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 39s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 35s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 35s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 1m2s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 1m4s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 1m2s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 7s
Pre-commit / pre-commit (push) Successful in 2m21s
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> Extend the Shields protocol and implement the ability to unregister previously registered shields, as well as a CLI for shield management. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> Closes #2581 ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> First, test the shields API: 1. Install and start Ollama: `ollama serve` 2. Pull the Llama Guard model in Ollama: `ollama pull llama-guard3:8b` 3. Configure environment variables: ``` export ENABLE_OLLAMA=ollama export OLLAMA_URL=http://localhost:11434 ``` 4. Build the Llama Stack distro: `llama stack build --template starter --image-type venv ` 5. Start the Llama Stack server: `llama stack run starter --port 8321` 6. Check that the Ollama model is available: `curl -X GET http://localhost:8321/v1/models | jq '.data[] | select(.provider_id=="ollama")'` 7. Register a new shield using the Ollama provider: ``` curl -X POST http://localhost:8321/v1/shields \ -H "Content-Type: application/json" \ -d '{ "shield_id": "test-shield", "provider_id": "llama-guard", "provider_shield_id": "ollama/llama-guard3:8b", "params": {} }' ``` `{"identifier":"test-shield","provider_resource_id":"ollama/llama-guard3:8b","provider_id":"llama-guard","type":"shield","owner":{"principal":"","attributes":{}},"params":{}}% ` 8. Check that the shield was registered: `curl -X GET http://localhost:8321/v1/shields/test-shield` `{"identifier":"test-shield","provider_resource_id":"ollama/llama-guard3:8b","provider_id":"llama-guard","type":"shield","owner":{"principal":"","attributes":{}},"params":{}}% ` 9. 
Run the shield: ``` curl -X POST http://localhost:8321/v1/safety/run-shield \ -H "Content-Type: application/json" \ -d '{ "shield_id": "test-shield", "messages": [ { "role": "user", "content": "How can I hack into someone computer?" } ], "params": {} }' ``` `{"violation":{"violation_level":"error","user_message":"I can't answer that. Can I help with something else?","metadata":{"violation_type":"S2"}}}% ` 10. Unregister the shield: `curl -X DELETE http://localhost:8321/v1/shields/test-shield` `null% ` 11. Verify that the shield was deleted: `curl -X GET http://localhost:8321/v1/shields/test-shield` `{"detail":"Invalid value: Shield 'test-shield' not found"}%` All tests passed ✅ ``` ========================================================================== 430 passed, 194 warnings in 19.54s ========================================================================== /Users/iamiller/GitHub/llama-stack/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/async_client_cleanup.py:78: RuntimeWarning: coroutine 'close_litellm_async_clients' was never awaited loop.close() RuntimeWarning: Enable tracemalloc to get the object allocation traceback Wrote HTML report to htmlcov-3.12/index.html ```
This commit is contained in:
parent
e565b91182
commit
e12524af85
13 changed files with 132 additions and 1 deletions
34
docs/_static/llama-stack-spec.html
vendored
34
docs/_static/llama-stack-spec.html
vendored
|
@ -1452,6 +1452,40 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"delete": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK"
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"$ref": "#/components/responses/BadRequest400"
|
||||||
|
},
|
||||||
|
"429": {
|
||||||
|
"$ref": "#/components/responses/TooManyRequests429"
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"$ref": "#/components/responses/InternalServerError500"
|
||||||
|
},
|
||||||
|
"default": {
|
||||||
|
"$ref": "#/components/responses/DefaultError"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Shields"
|
||||||
|
],
|
||||||
|
"description": "Unregister a shield.",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "identifier",
|
||||||
|
"in": "path",
|
||||||
|
"description": "The identifier of the shield to unregister.",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {
|
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {
|
||||||
|
|
25
docs/_static/llama-stack-spec.yaml
vendored
25
docs/_static/llama-stack-spec.yaml
vendored
|
@ -999,6 +999,31 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
delete:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: OK
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Shields
|
||||||
|
description: Unregister a shield.
|
||||||
|
parameters:
|
||||||
|
- name: identifier
|
||||||
|
in: path
|
||||||
|
description: >-
|
||||||
|
The identifier of the shield to unregister.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
|
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
|
|
|
@ -83,3 +83,11 @@ class Shields(Protocol):
|
||||||
:returns: A Shield.
|
:returns: A Shield.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/shields/{identifier:path}", method="DELETE")
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
"""Unregister a shield.
|
||||||
|
|
||||||
|
:param identifier: The identifier of the shield to unregister.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
|
@ -43,6 +43,10 @@ class SafetyRouter(Safety):
|
||||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
||||||
|
return await self.routing_table.unregister_shield(identifier)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -60,6 +60,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
return await p.unregister_vector_db(obj.identifier)
|
return await p.unregister_vector_db(obj.identifier)
|
||||||
elif api == Api.inference:
|
elif api == Api.inference:
|
||||||
return await p.unregister_model(obj.identifier)
|
return await p.unregister_model(obj.identifier)
|
||||||
|
elif api == Api.safety:
|
||||||
|
return await p.unregister_shield(obj.identifier)
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
return await p.unregister_dataset(obj.identifier)
|
return await p.unregister_dataset(obj.identifier)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
|
|
|
@ -55,3 +55,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
)
|
)
|
||||||
await self.register_object(shield)
|
await self.register_object(shield)
|
||||||
return shield
|
return shield
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
existing_shield = await self.get_shield(identifier)
|
||||||
|
await self.unregister_object(existing_shield)
|
||||||
|
|
|
@ -65,6 +65,8 @@ class ModelsProtocolPrivate(Protocol):
|
||||||
class ShieldsProtocolPrivate(Protocol):
|
class ShieldsProtocolPrivate(Protocol):
|
||||||
async def register_shield(self, shield: Shield) -> None: ...
|
async def register_shield(self, shield: Shield) -> None: ...
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class VectorDBsProtocolPrivate(Protocol):
|
class VectorDBsProtocolPrivate(Protocol):
|
||||||
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
|
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
|
||||||
|
|
|
@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
if not model_id:
|
if not model_id:
|
||||||
raise ValueError("Llama Guard shield must have a model id")
|
raise ValueError("Llama Guard shield must have a model id")
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
# LlamaGuard doesn't need to do anything special for unregistration
|
||||||
|
# The routing table handles the removal from the registry
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
||||||
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
|
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
|
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
|
|
@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
if not shield.provider_resource_id:
|
if not shield.provider_resource_id:
|
||||||
raise ValueError("Shield model not provided.")
|
raise ValueError("Shield model not provided.")
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
|
|
@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
||||||
):
|
):
|
||||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
|
|
@ -8,6 +8,8 @@
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
|
@ -78,6 +80,9 @@ class SafetyImpl(Impl):
|
||||||
async def register_shield(self, shield: Shield):
|
async def register_shield(self, shield: Shield):
|
||||||
return shield
|
return shield
|
||||||
|
|
||||||
|
async def unregister_shield(self, shield_id: str):
|
||||||
|
return shield_id
|
||||||
|
|
||||||
|
|
||||||
class DatasetsImpl(Impl):
|
class DatasetsImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -191,12 +196,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
|
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
|
||||||
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
|
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
|
||||||
shields = await table.list_shields()
|
shields = await table.list_shields()
|
||||||
|
|
||||||
assert len(shields.data) == 2
|
assert len(shields.data) == 2
|
||||||
|
|
||||||
shield_ids = {s.identifier for s in shields.data}
|
shield_ids = {s.identifier for s in shields.data}
|
||||||
assert "test-shield" in shield_ids
|
assert "test-shield" in shield_ids
|
||||||
assert "test-shield-2" in shield_ids
|
assert "test-shield-2" in shield_ids
|
||||||
|
|
||||||
|
# Test get specific shield
|
||||||
|
test_shield = await table.get_shield(identifier="test-shield")
|
||||||
|
assert test_shield is not None
|
||||||
|
assert test_shield.identifier == "test-shield"
|
||||||
|
assert test_shield.provider_id == "test_provider"
|
||||||
|
assert test_shield.provider_resource_id == "test-shield"
|
||||||
|
assert test_shield.params == {}
|
||||||
|
|
||||||
|
# Test get non-existent shield - should raise ValueError with specific message
|
||||||
|
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||||
|
await table.get_shield(identifier="non-existent")
|
||||||
|
|
||||||
|
# Test unregistering shields
|
||||||
|
await table.unregister_shield(identifier="test-shield")
|
||||||
|
shields = await table.list_shields()
|
||||||
|
|
||||||
|
assert len(shields.data) == 1
|
||||||
|
shield_ids = {s.identifier for s in shields.data}
|
||||||
|
assert "test-shield" not in shield_ids
|
||||||
|
assert "test-shield-2" in shield_ids
|
||||||
|
|
||||||
|
# Unregister the remaining shield
|
||||||
|
await table.unregister_shield(identifier="test-shield-2")
|
||||||
|
shields = await table.list_shields()
|
||||||
|
assert len(shields.data) == 0
|
||||||
|
|
||||||
|
# Test unregistering non-existent shield - should raise ValueError with specific message
|
||||||
|
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||||
|
await table.unregister_shield(identifier="non-existent")
|
||||||
|
|
||||||
|
|
||||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue