feat: create unregister shield API endpoint and CLI for shields management

This commit is contained in:
r3v5 2025-07-22 12:28:18 +01:00
parent 537dc693ee
commit 93be3aa9b6
No known key found for this signature in database
GPG key ID: 7758B9F272DE67D9
20 changed files with 563 additions and 1 deletions

View file

@ -1452,6 +1452,40 @@
}
}
]
},
"delete": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Shields"
],
"description": "Unregister a shield.",
"parameters": [
{
"name": "identifier",
"in": "path",
"description": "The identifier of the shield to unregister.",
"required": true,
"schema": {
"type": "string"
}
}
]
}
},
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {

View file

@ -999,6 +999,31 @@ paths:
required: true
schema:
type: string
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
get:
responses:

View file

@ -79,3 +79,11 @@ class Shields(Protocol):
:returns: A Shield.
"""
...
# HTTP binding for shield removal: DELETE on /shields/{identifier}.
# NOTE(review): the docstring text below matches the description strings in the
# OpenAPI spec files changed by this commit, so it presumably feeds spec
# generation - keep the wording in sync with the spec when editing. TODO confirm.
# NOTE(review): the ":path" suffix on the route parameter presumably lets
# identifiers contain "/" - verify against the webmethod/router implementation.
@webmethod(route="/shields/{identifier:path}", method="DELETE")
async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield.
:param identifier: The identifier of the shield to unregister.
"""
...

View file

@ -8,6 +8,7 @@ import argparse
from .download import Download
from .model import ModelParser
from .shield import ShieldParser
from .stack import StackParser
from .stack.utils import print_subcommand_description
from .verify_download import VerifyDownload
@ -31,6 +32,7 @@ class LlamaCLIParser:
# Add sub-commands
ModelParser.create(subparsers)
ShieldParser.create(subparsers)
StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shield import ShieldParser
__all__ = ["ShieldParser"]

View file

@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
import httpx
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class ShieldDescribe(Subcommand):
    """Show details about a shield

    Implements ``llama shield describe <shield_id>``: issues
    ``GET {url}/v1/shields/{shield_id}`` and prints the result either as a
    two-column table or as raw JSON.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Register the ``describe`` sub-command on ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "describe",
            prog="llama shield describe",
            description="Show details about a shield",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_describe_cmd)

    def _add_arguments(self):
        """Declare the positional shield id plus the ``--url`` and ``--output-format`` options."""
        self.parser.add_argument(
            "shield_id",
            type=str,
            help="The identifier of the shield to describe",
        )
        self.parser.add_argument(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )
        self.parser.add_argument(
            "--output-format",
            choices=["table", "json"],
            default="table",
            help="Output format (default: table)",
        )

    def _run_shield_describe_cmd(self, args: argparse.Namespace) -> None:
        """Fetch one shield from the server and print it.

        :param args: Parsed CLI arguments (``shield_id``, ``url``, ``output_format``).

        On any failure (HTTP error status, connection problem, or unexpected
        error) a message is printed and the process exits with status 1, so
        that callers and scripts can detect the failure. This matches the
        behavior of the ``register`` and ``unregister`` shield commands.
        """
        # Function-scope import: this module's import block does not provide ``sys``.
        import sys

        try:
            response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}")
            response.raise_for_status()
            shield = response.json()

            if args.output_format == "json":
                print(json.dumps(shield, indent=2))
                return

            headers = ["Property", "Value"]

            # The response is read under either key naming
            # ("identifier"/"provider_resource_id" or "shield_id"/"provider_shield_id").
            shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
            provider_id = shield.get("provider_id", "<Not Set>")
            provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))
            resource_type = shield.get("type", "shield")

            rows = [
                ("Shield ID", shield_id),
                ("Provider ID", provider_id),
                ("Provider Shield ID", provider_shield_id),
                ("Resource Type", resource_type),
            ]
            if shield.get("params"):
                rows.append(("Parameters", json.dumps(shield["params"], indent=2)))
            else:
                rows.append(("Parameters", "<None>"))

            print_table(
                rows,
                headers,
                separate_rows=True,
            )
        except httpx.HTTPStatusError as e:
            if e.response.status_code in (400, 404):
                print(f"Shield '{args.shield_id}' not found.")
            else:
                print(f"HTTP error {e.response.status_code}: {e.response.text}")
            # Report failure through the exit status, not only through stdout.
            sys.exit(1)
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            sys.exit(1)
        except Exception as e:
            print(f"Error describing shield: {e}")
            sys.exit(1)

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
import httpx
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class ShieldList(Subcommand):
    """List available shields

    Implements ``llama shield list``: issues ``GET {url}/v1/shields`` and
    prints the shields found under the response's ``data`` key either as a
    table or as raw JSON.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Register the ``list`` sub-command on ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "list",
            prog="llama shield list",
            description="Show available shields",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_list_cmd)

    def _add_arguments(self):
        """Declare the ``--url`` and ``--output-format`` options."""
        self.parser.add_argument(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )
        self.parser.add_argument(
            "--output-format",
            choices=["table", "json"],
            default="table",
            help="Output format (default: table)",
        )

    def _run_shield_list_cmd(self, args: argparse.Namespace) -> None:
        """Fetch all shields from the server and print them.

        :param args: Parsed CLI arguments (``url``, ``output_format``).

        An empty result is not an error ("No shields found." is printed and
        the command returns normally). On any failure (HTTP error status,
        connection problem, or unexpected error) a message is printed and the
        process exits with status 1, matching the ``register`` and
        ``unregister`` shield commands.
        """
        # Function-scope import: this module's import block does not provide ``sys``.
        import sys

        try:
            response = httpx.get(f"{args.url}/v1/shields")
            response.raise_for_status()
            data = response.json()
            shields = data.get("data", [])

            if args.output_format == "json":
                print(json.dumps(shields, indent=2))
                return

            if not shields:
                print("No shields found.")
                return

            headers = ["Shield ID", "Provider ID", "Provider Shield ID", "Parameters"]
            rows = []
            for shield in shields:
                params_str = ""
                if shield.get("params"):
                    # Compact JSON, capped at 50 characters so the table stays readable.
                    params_str = json.dumps(shield["params"], separators=(",", ":"))
                    if len(params_str) > 50:
                        params_str = params_str[:47] + "..."
                rows.append(
                    [
                        shield.get("identifier", shield.get("shield_id", "-")),
                        shield.get("provider_id", "-"),
                        shield.get("provider_resource_id", shield.get("provider_shield_id", "-")),
                        params_str or "-",
                    ]
                )

            print_table(
                rows,
                headers,
                separate_rows=True,
            )
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            # Report failure through the exit status, not only through stdout.
            sys.exit(1)
        except httpx.HTTPStatusError as e:
            print(f"HTTP error {e.response.status_code}: {e.response.text}")
            sys.exit(1)
        except Exception as e:
            print(f"Error listing shields: {e}")
            sys.exit(1)

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
import sys
import httpx
from llama_stack.cli.subcommand import Subcommand
class ShieldRegister(Subcommand):
    """Register a new shield"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Attach the ``register`` sub-command to ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "register",
            prog="llama shield register",
            description="Register a new shield",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_register_cmd)

    def _add_arguments(self):
        """Declare the shield id positional and the optional provider/params/url flags."""
        add = self.parser.add_argument
        add("shield_id", type=str, help="The identifier for the shield to register")

        # Plain optional string flags, declared in the order they appear in --help.
        string_flags = (
            ("--provider-id", "The provider ID for the shield"),
            ("--provider-shield-id", "The provider-specific shield identifier"),
            ("--params", 'Shield parameters as JSON string (e.g., \'{"key": "value"}\')'),
        )
        for flag, help_text in string_flags:
            add(flag, type=str, help=help_text)

        add(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )

    def _run_shield_register_cmd(self, args: argparse.Namespace) -> None:
        """POST the shield definition to ``{url}/v1/shields`` and print the outcome.

        Invalid ``--params`` JSON, HTTP error statuses, connection problems and
        unexpected errors all print a message and exit with status 1.
        """
        try:
            shield_params = None
            if args.params:
                try:
                    shield_params = json.loads(args.params)
                except json.JSONDecodeError as e:
                    print(f"Error parsing parameters JSON: {e}")
                    sys.exit(1)

            # Only fields that were actually supplied (truthy) are sent to the server.
            optional_fields = (
                ("provider_id", args.provider_id),
                ("provider_shield_id", args.provider_shield_id),
                ("params", shield_params),
            )
            payload = {"shield_id": args.shield_id}
            payload.update({key: value for key, value in optional_fields if value})

            response = httpx.post(
                f"{args.url}/v1/shields",
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            shield = response.json()

            registered_id = shield.get("identifier", args.shield_id)
            provider_id = shield.get("provider_id", "<Not Set>")
            provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))

            print(f" Shield '{registered_id}' registered successfully!")
            print(f" Provider ID: {provider_id}")
            print(f" Provider Shield ID: {provider_shield_id}")
            if shield.get("params"):
                print(f" Parameters: {json.dumps(shield['params'], indent=2)}")
        except httpx.HTTPStatusError as e:
            duplicate = e.response.status_code == 400 and "already exists" in e.response.text
            if duplicate:
                print(f"Shield '{args.shield_id}' already exists.")
            else:
                print(f"HTTP error {e.response.status_code}: {e.response.text}")
            sys.exit(1)
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            sys.exit(1)
        except Exception as e:
            print(f"Error registering shield: {e}")
            sys.exit(1)

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
from .describe import ShieldDescribe
from .list import ShieldList
from .register import ShieldRegister
from .unregister import ShieldUnregister
class ShieldParser(Subcommand):
    """Parser for shield commands"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Create the ``llama shield`` command and attach its sub-commands."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "shield",
            prog="llama shield",
            description="Manage safety shields",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        # With no sub-command given, fall back to printing the help text.
        self.parser.set_defaults(func=lambda args: self.parser.print_help())

        shield_subparsers = self.parser.add_subparsers(title="shield_subcommands")

        # Registration order is kept as list, register, describe, unregister.
        for command in (ShieldList, ShieldRegister, ShieldDescribe, ShieldUnregister):
            command.create(shield_subparsers)

        print_subcommand_description(self.parser, shield_subparsers)

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import sys
import httpx
from llama_stack.cli.subcommand import Subcommand
class ShieldUnregister(Subcommand):
    """Unregister a shield"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Attach the ``unregister`` sub-command to ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "unregister",
            prog="llama shield unregister",
            description="Unregister a shield",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_unregister_cmd)

    def _add_arguments(self):
        """Declare the shield id positional and the ``--url`` / ``--force`` options."""
        add = self.parser.add_argument
        add("shield_id", type=str, help="The identifier of the shield to unregister")
        add(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )
        add("--force", action="store_true", help="Force unregister without confirmation")

    def _fetch_shield_or_exit(self, args: argparse.Namespace):
        """GET the shield; print "not found" and exit 1 on a 400/404 response.

        Any other HTTP error status is re-raised for the caller to report.
        """
        try:
            lookup = httpx.get(f"{args.url}/v1/shields/{args.shield_id}")
            lookup.raise_for_status()
            return lookup.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code not in (400, 404):
                raise
            print(f"Shield '{args.shield_id}' not found.")
            sys.exit(1)

    def _confirm(self, args: argparse.Namespace, shield) -> bool:
        """Show the shield that is about to be removed and ask the user to confirm."""
        shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
        provider_id = shield.get("provider_id", "<Not Set>")
        provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))

        print("Shield to unregister:")
        print(f" - Shield ID: {shield_id}")
        print(f" - Provider ID: {provider_id}")
        print(f" - Provider Shield ID: {provider_shield_id}")

        answer = input(f"\nAre you sure you want to unregister shield '{args.shield_id}'? (y/N): ")
        return answer.lower() in ("y", "yes")

    def _run_shield_unregister_cmd(self, args: argparse.Namespace) -> None:
        """Look the shield up, optionally confirm with the user, then DELETE it.

        Every failure path prints a message and exits with status 1; a declined
        confirmation prints "Unregister cancelled." and returns normally.
        """
        try:
            # The lookup happens even with --force, so a missing shield is
            # reported before any DELETE request is sent.
            shield = self._fetch_shield_or_exit(args)

            if not args.force and not self._confirm(args, shield):
                print("Unregister cancelled.")
                return

            deletion = httpx.delete(f"{args.url}/v1/shields/{args.shield_id}")
            deletion.raise_for_status()
            print(f"Shield '{args.shield_id}' unregistered successfully!")
        except httpx.HTTPStatusError as e:
            if e.response.status_code in (400, 404):
                print(f"Shield '{args.shield_id}' not found.")
            else:
                print(f"HTTP error {e.response.status_code}: {e.response.text}")
            sys.exit(1)
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            sys.exit(1)
        except Exception as e:
            print(f"Error unregistering shield: {e}")
            sys.exit(1)

View file

@ -43,6 +43,10 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def unregister_shield(self, identifier: str) -> None:
    """Log the request, then delegate shield removal to the routing table."""
    logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
    outcome = await self.routing_table.unregister_shield(identifier)
    return outcome
async def run_shield(
self,
shield_id: str,

View file

@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.safety:
return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:

View file

@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
await self.register_object(shield)
return shield
async def unregister_shield(self, identifier: str) -> None:
    """Remove a registered shield from the routing table.

    :param identifier: The identifier of the shield to unregister.
    :raises ValueError: If no shield with that identifier is found.
    """
    existing_shield = await self.get_shield(identifier)
    # Defensive guard in case the lookup returns None instead of raising.
    if existing_shield is None:
        raise ValueError(f"Shield '{identifier}' not found")
    await self.unregister_object(existing_shield)
    # Log success only after the removal has completed, so a failing
    # unregister_object() never leaves a misleading "success" entry behind.
    logger.info(f"Shield {identifier} was unregistered successfully.")

View file

@ -51,6 +51,8 @@ class ModelsProtocolPrivate(Protocol):
class ShieldsProtocolPrivate(Protocol):
    """Structural interface for the shield registration hooks of a provider."""

    async def register_shield(self, shield: Shield) -> None:
        """Hook called with the shield being registered."""
        ...

    async def unregister_shield(self, identifier: str) -> None:
        """Hook called with the identifier of the shield being unregistered."""
        ...
class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...

View file

@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
# The model will be validated during runtime when making inference calls
pass
async def unregister_shield(self, identifier: str) -> None:
    """No-op: LlamaGuard needs no provider-side work to unregister a shield.

    Removal from the registry is handled by the routing table.
    """
async def run_shield(
self,
shield_id: str,

View file

@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
async def unregister_shield(self, identifier: str) -> None:
    """No-op: Prompt Guard performs no provider-side work when a shield is unregistered."""
async def run_shield(
self,
shield_id: str,

View file

@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
)
async def unregister_shield(self, identifier: str) -> None:
    """No-op: the Bedrock adapter performs no provider-side work when a shield is unregistered."""
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
) -> RunShieldResponse:

View file

@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def unregister_shield(self, identifier: str) -> None:
    """No-op: the NVIDIA adapter performs no provider-side work when a shield is unregistered."""
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:

View file

@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
):
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
async def unregister_shield(self, identifier: str) -> None:
    """No-op: the SambaNova adapter performs no provider-side work when a shield is unregistered."""
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:

View file

@ -8,6 +8,8 @@
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@ -53,6 +55,9 @@ class SafetyImpl(Impl):
async def register_shield(self, shield: Shield):
    """Test stub: hand the shield straight back without doing any work."""
    registered = shield
    return registered
async def unregister_shield(self, shield_id: str):
    """Test stub: echo the given shield id back without doing any work."""
    echoed = shield_id
    return echoed
class DatasetsImpl(Impl):
def __init__(self):
@ -166,12 +171,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
shields = await table.list_shields()
assert len(shields.data) == 2
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" in shield_ids
assert "test-shield-2" in shield_ids
# Test get specific shield
test_shield = await table.get_shield(identifier="test-shield")
assert test_shield is not None
assert test_shield.identifier == "test-shield"
assert test_shield.provider_id == "test_provider"
assert test_shield.provider_resource_id == "test-shield"
assert test_shield.params == {}
# Test get non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.get_shield(identifier="non-existent")
# Test unregistering shields
await table.unregister_shield(identifier="test-shield")
shields = await table.list_shields()
assert len(shields.data) == 1
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" not in shield_ids
assert "test-shield-2" in shield_ids
# Unregister the remaining shield
await table.unregister_shield(identifier="test-shield-2")
shields = await table.list_shields()
assert len(shields.data) == 0
# Test unregistering non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.unregister_shield(identifier="non-existent")
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})