This commit is contained in:
IAN MILLER 2025-07-25 22:55:46 -04:00 committed by GitHub
commit 53ca20fade
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 563 additions and 1 deletions

View file

@ -1452,6 +1452,40 @@
} }
} }
] ]
},
"delete": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Shields"
],
"description": "Unregister a shield.",
"parameters": [
{
"name": "identifier",
"in": "path",
"description": "The identifier of the shield to unregister.",
"required": true,
"schema": {
"type": "string"
}
}
]
} }
}, },
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": { "/v1/telemetry/traces/{trace_id}/spans/{span_id}": {

View file

@ -999,6 +999,31 @@ paths:
required: true required: true
schema: schema:
type: string type: string
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Shields
description: Unregister a shield.
parameters:
- name: identifier
in: path
description: >-
The identifier of the shield to unregister.
required: true
schema:
type: string
/v1/telemetry/traces/{trace_id}/spans/{span_id}: /v1/telemetry/traces/{trace_id}/spans/{span_id}:
get: get:
responses: responses:

View file

@ -79,3 +79,11 @@ class Shields(Protocol):
:returns: A Shield. :returns: A Shield.
""" """
... ...
# NOTE(review): the summary and :param: text below match the "description"
# strings of the Shields `delete` operation in the OpenAPI spec; presumably the
# spec is generated from this docstring, so keep the two in sync (TODO confirm).
@webmethod(route="/shields/{identifier:path}", method="DELETE")
async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield.
:param identifier: The identifier of the shield to unregister.
"""
...

View file

@ -8,6 +8,7 @@ import argparse
from .download import Download from .download import Download
from .model import ModelParser from .model import ModelParser
from .shield import ShieldParser
from .stack import StackParser from .stack import StackParser
from .stack.utils import print_subcommand_description from .stack.utils import print_subcommand_description
from .verify_download import VerifyDownload from .verify_download import VerifyDownload
@ -31,6 +32,7 @@ class LlamaCLIParser:
# Add sub-commands # Add sub-commands
ModelParser.create(subparsers) ModelParser.create(subparsers)
ShieldParser.create(subparsers)
StackParser.create(subparsers) StackParser.create(subparsers)
Download.create(subparsers) Download.create(subparsers)
VerifyDownload.create(subparsers) VerifyDownload.create(subparsers)

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shield import ShieldParser
__all__ = ["ShieldParser"]

View file

@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
import httpx
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class ShieldDescribe(Subcommand):
    """Show details about a shield.

    Implements ``llama shield describe``: fetches a single shield from a
    Llama Stack server over HTTP and prints it either as a property/value
    table or as JSON.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Register the ``describe`` sub-command on ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "describe",
            prog="llama shield describe",
            description="Show details about a shield",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_describe_cmd)

    def _add_arguments(self):
        """Declare the positional shield id and the ``--url`` / ``--output-format`` options."""
        self.parser.add_argument(
            "shield_id",
            type=str,
            help="The identifier of the shield to describe",
        )
        self.parser.add_argument(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )
        self.parser.add_argument(
            "--output-format",
            choices=["table", "json"],
            default="table",
            help="Output format (default: table)",
        )

    def _run_shield_describe_cmd(self, args: argparse.Namespace) -> None:
        """Fetch the shield named by ``args.shield_id`` and print it.

        :param args: Parsed command-line arguments (``shield_id``, ``url``,
            ``output_format``).
        :raises SystemExit: With status 1 when the shield cannot be fetched or
            rendered, so that callers and shell scripts can detect the failure
            instead of seeing a successful exit next to an error message.
        """
        try:
            response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}")
            response.raise_for_status()
            shield = response.json()

            if args.output_format == "json":
                print(json.dumps(shield, indent=2))
                return

            headers = ["Property", "Value"]
            # The response may use either the resource-style keys or the
            # shield-specific ones; prefer the former and fall back to the latter.
            shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
            provider_id = shield.get("provider_id", "<Not Set>")
            provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))
            resource_type = shield.get("type", "shield")
            rows = [
                ("Shield ID", shield_id),
                ("Provider ID", provider_id),
                ("Provider Shield ID", provider_shield_id),
                ("Resource Type", resource_type),
            ]
            if shield.get("params"):
                rows.append(("Parameters", json.dumps(shield["params"], indent=2)))
            else:
                rows.append(("Parameters", "<None>"))

            print_table(
                rows,
                headers,
                separate_rows=True,
            )
        except httpx.HTTPStatusError as e:
            if e.response.status_code in (400, 404):
                print(f"Shield '{args.shield_id}' not found.")
            else:
                print(f"HTTP error {e.response.status_code}: {e.response.text}")
            # Report the failure through the exit status as well as the message.
            raise SystemExit(1) from e
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            raise SystemExit(1) from e
        except Exception as e:
            print(f"Error describing shield: {e}")
            raise SystemExit(1) from e

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
import httpx
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class ShieldList(Subcommand):
    """List available shields.

    Implements ``llama shield list``: fetches all shields from a Llama Stack
    server over HTTP and prints them as a table or as JSON.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Register the ``list`` sub-command on ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "list",
            prog="llama shield list",
            description="Show available shields",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_list_cmd)

    def _add_arguments(self):
        """Declare the ``--url`` and ``--output-format`` options."""
        self.parser.add_argument(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )
        self.parser.add_argument(
            "--output-format",
            choices=["table", "json"],
            default="table",
            help="Output format (default: table)",
        )

    def _run_shield_list_cmd(self, args: argparse.Namespace) -> None:
        """Fetch every registered shield and print the result.

        An empty result is not an error: "No shields found." is printed and the
        command returns normally.

        :param args: Parsed command-line arguments (``url``, ``output_format``).
        :raises SystemExit: With status 1 when the server cannot be reached,
            answers with an HTTP error, or the response cannot be rendered, so
            that the failure is visible in the exit status and not only in the
            printed message.
        """
        try:
            response = httpx.get(f"{args.url}/v1/shields")
            response.raise_for_status()
            data = response.json()
            shields = data.get("data", [])

            if args.output_format == "json":
                print(json.dumps(shields, indent=2))
                return

            if not shields:
                print("No shields found.")
                return

            headers = ["Shield ID", "Provider ID", "Provider Shield ID", "Parameters"]
            rows = []
            for shield in shields:
                params_str = ""
                if shield.get("params"):
                    # Compact JSON, truncated so one shield cannot blow up the table width.
                    params_str = json.dumps(shield["params"], separators=(",", ":"))
                    if len(params_str) > 50:
                        params_str = params_str[:47] + "..."
                rows.append(
                    [
                        shield.get("identifier", shield.get("shield_id", "-")),
                        shield.get("provider_id", "-"),
                        shield.get("provider_resource_id", shield.get("provider_shield_id", "-")),
                        params_str or "-",
                    ]
                )

            print_table(
                rows,
                headers,
                separate_rows=True,
            )
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            # Report the failure through the exit status as well as the message.
            raise SystemExit(1) from e
        except httpx.HTTPStatusError as e:
            print(f"HTTP error {e.response.status_code}: {e.response.text}")
            raise SystemExit(1) from e
        except Exception as e:
            print(f"Error listing shields: {e}")
            raise SystemExit(1) from e

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import json
import sys
import httpx
from llama_stack.cli.subcommand import Subcommand
class ShieldRegister(Subcommand):
    """Register a new shield"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Attach the ``register`` sub-command to ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "register",
            prog="llama shield register",
            description="Register a new shield",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_register_cmd)

    def _add_arguments(self):
        """Declare the shield id plus the optional provider, params and server URL flags."""
        add = self.parser.add_argument
        add("shield_id", type=str, help="The identifier for the shield to register")
        add("--provider-id", type=str, help="The provider ID for the shield")
        add("--provider-shield-id", type=str, help="The provider-specific shield identifier")
        add(
            "--params",
            type=str,
            help='Shield parameters as JSON string (e.g., \'{"key": "value"}\')',
        )
        add(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )

    def _run_shield_register_cmd(self, args: argparse.Namespace) -> None:
        """POST the shield described by ``args`` to the server and report the outcome.

        Exits the process with status 1 when the ``--params`` JSON cannot be
        parsed, the server answers with an HTTP error, the server is
        unreachable, or any other error occurs.
        """
        try:
            shield_params = None
            if args.params:
                try:
                    shield_params = json.loads(args.params)
                except json.JSONDecodeError as e:
                    print(f"Error parsing parameters JSON: {e}")
                    sys.exit(1)

            # Only the shield id is always sent; optional fields are included
            # when they carry a truthy value.
            payload = {"shield_id": args.shield_id}
            optional_fields = (
                ("provider_id", args.provider_id),
                ("provider_shield_id", args.provider_shield_id),
                ("params", shield_params),
            )
            payload.update({key: value for key, value in optional_fields if value})

            response = httpx.post(
                f"{args.url}/v1/shields",
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            shield = response.json()

            registered_id = shield.get("identifier", args.shield_id)
            print(f" Shield '{registered_id}' registered successfully!")
            registered_provider = shield.get("provider_id", "<Not Set>")
            print(f" Provider ID: {registered_provider}")
            registered_resource = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))
            print(f" Provider Shield ID: {registered_resource}")
            registered_params = shield.get("params")
            if registered_params:
                print(f" Parameters: {json.dumps(registered_params, indent=2)}")
        except httpx.HTTPStatusError as e:
            already_registered = e.response.status_code == 400 and "already exists" in e.response.text
            if already_registered:
                print(f"Shield '{args.shield_id}' already exists.")
            else:
                print(f"HTTP error {e.response.status_code}: {e.response.text}")
            sys.exit(1)
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            sys.exit(1)
        except Exception as e:
            print(f"Error registering shield: {e}")
            sys.exit(1)

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
from .describe import ShieldDescribe
from .list import ShieldList
from .register import ShieldRegister
from .unregister import ShieldUnregister
class ShieldParser(Subcommand):
    """Parser for shield commands"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Create the ``llama shield`` command and hang its sub-commands off it."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "shield",
            prog="llama shield",
            description="Manage safety shields",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        # With no sub-command given, fall back to printing the help text.
        self.parser.set_defaults(func=lambda args: self.parser.print_help())

        subparsers = self.parser.add_subparsers(title="shield_subcommands")

        # Registration order is kept: list, register, describe, unregister.
        for shield_command in (ShieldList, ShieldRegister, ShieldDescribe, ShieldUnregister):
            shield_command.create(subparsers)

        print_subcommand_description(self.parser, subparsers)

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import sys
import httpx
from llama_stack.cli.subcommand import Subcommand
class ShieldUnregister(Subcommand):
    """Unregister a shield"""

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Attach the ``unregister`` sub-command to ``subparsers``."""
        super().__init__()
        self.parser = subparsers.add_parser(
            "unregister",
            prog="llama shield unregister",
            description="Unregister a shield",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_shield_unregister_cmd)

    def _add_arguments(self):
        """Declare the shield id, the server URL and the ``--force`` switch."""
        add = self.parser.add_argument
        add("shield_id", type=str, help="The identifier of the shield to unregister")
        add(
            "--url",
            type=str,
            default="http://localhost:8321",
            help="URL of the Llama Stack server (default: http://localhost:8321)",
        )
        add("--force", action="store_true", help="Force unregister without confirmation")

    def _run_shield_unregister_cmd(self, args: argparse.Namespace) -> None:
        """Look the shield up, optionally ask for confirmation, then delete it.

        The shield is fetched first so a missing shield is reported before any
        prompt is shown. Unless ``--force`` is given the user must answer "y"
        or "yes" (case-insensitive); any other answer cancels without error.
        Every failure path exits the process with status 1.
        """
        try:
            shield_url = f"{args.url}/v1/shields/{args.shield_id}"

            try:
                lookup = httpx.get(shield_url)
                lookup.raise_for_status()
                shield = lookup.json()
            except httpx.HTTPStatusError as lookup_error:
                # Anything other than "not found" is handled by the outer handlers.
                if lookup_error.response.status_code not in (400, 404):
                    raise
                print(f"Shield '{args.shield_id}' not found.")
                sys.exit(1)

            if not args.force:
                shown_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
                shown_provider = shield.get("provider_id", "<Not Set>")
                shown_resource = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))
                print("Shield to unregister:")
                print(f" - Shield ID: {shown_id}")
                print(f" - Provider ID: {shown_provider}")
                print(f" - Provider Shield ID: {shown_resource}")

                answer = input(f"\nAre you sure you want to unregister shield '{args.shield_id}'? (y/N): ")
                confirmed = answer.lower() in ("y", "yes")
                if not confirmed:
                    print("Unregister cancelled.")
                    return

            deletion = httpx.delete(shield_url)
            deletion.raise_for_status()
            print(f"Shield '{args.shield_id}' unregistered successfully!")
        except httpx.HTTPStatusError as e:
            if e.response.status_code in (400, 404):
                print(f"Shield '{args.shield_id}' not found.")
            else:
                print(f"HTTP error {e.response.status_code}: {e.response.text}")
            sys.exit(1)
        except httpx.RequestError as e:
            print(f"Error connecting to Llama Stack server: {e}")
            sys.exit(1)
        except Exception as e:
            print(f"Error unregistering shield: {e}")
            sys.exit(1)

View file

@ -43,6 +43,10 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.register_shield: {shield_id}") logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def unregister_shield(self, identifier: str) -> None:
    """Hand a shield unregistration request over to the routing table.

    :param identifier: The identifier of the shield to unregister.
    """
    logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
    outcome = await self.routing_table.unregister_shield(identifier)
    return outcome
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,

View file

@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_vector_db(obj.identifier) return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference: elif api == Api.inference:
return await p.unregister_model(obj.identifier) return await p.unregister_model(obj.identifier)
elif api == Api.safety:
return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio: elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier) return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime: elif api == Api.tool_runtime:

View file

@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
) )
await self.register_object(shield) await self.register_object(shield)
return shield return shield
async def unregister_shield(self, identifier: str) -> None:
    """Remove a previously registered shield from the routing table.

    :param identifier: The identifier of the shield to unregister.
    :raises ValueError: If no shield with this identifier is registered.
    """
    existing_shield = await self.get_shield(identifier)
    if existing_shield is None:
        raise ValueError(f"Shield '{identifier}' not found")
    await self.unregister_object(existing_shield)
    # Log success only once the removal has actually completed; logging first
    # would record a success even when unregister_object raises.
    logger.info(f"Shield {identifier} was unregistered successfully.")

View file

@ -65,6 +65,8 @@ class ModelsProtocolPrivate(Protocol):
class ShieldsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ... async def register_shield(self, shield: Shield) -> None: ...
# Provider-side hook for removing a shield, given its identifier.
# NOTE(review): presumably what unregister_object_from_provider calls for
# Api.safety objects — confirm against the routing-table code.
async def unregister_shield(self, identifier: str) -> None: ...
class VectorDBsProtocolPrivate(Protocol): class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ... async def register_vector_db(self, vector_db: VectorDB) -> None: ...

View file

@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if not model_id: if not model_id:
raise ValueError("Llama Guard shield must have a model id") raise ValueError("Llama Guard shield must have a model id")
async def unregister_shield(self, identifier: str) -> None:
    """Accept a shield unregistration without doing any provider-side work.

    Llama Guard needs nothing special on unregistration; removing the shield
    from the registry is handled by the routing table.

    :param identifier: The identifier of the shield being unregistered (unused).
    """
    return None
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,

View file

@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if shield.provider_resource_id != PROMPT_GUARD_MODEL: if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ") raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
async def unregister_shield(self, identifier: str) -> None:
    """No-op: nothing is cleaned up on the provider side when a shield is unregistered.

    :param identifier: The identifier of the shield being unregistered (unused).
    """
    return None
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,

View file

@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock" f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
) )
async def unregister_shield(self, identifier: str) -> None:
    """No-op: nothing is cleaned up on the provider side when a shield is unregistered.

    :param identifier: The identifier of the shield being unregistered (unused).
    """
    return None
async def run_shield( async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:

View file

@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
if not shield.provider_resource_id: if not shield.provider_resource_id:
raise ValueError("Shield model not provided.") raise ValueError("Shield model not provided.")
async def unregister_shield(self, identifier: str) -> None:
    """No-op: nothing is cleaned up on the provider side when a shield is unregistered.

    :param identifier: The identifier of the shield being unregistered (unused).
    """
    return None
async def run_shield( async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse: ) -> RunShieldResponse:

View file

@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
): ):
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
async def unregister_shield(self, identifier: str) -> None:
    """No-op: nothing is cleaned up on the provider side when a shield is unregistered.

    :param identifier: The identifier of the shield being unregistered (unused).
    """
    return None
async def run_shield( async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse: ) -> RunShieldResponse:

View file

@ -8,6 +8,8 @@
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
@ -78,6 +80,9 @@ class SafetyImpl(Impl):
async def register_shield(self, shield: Shield): async def register_shield(self, shield: Shield):
return shield return shield
async def unregister_shield(self, shield_id: str):
    """Test double: hand back the identifier it was asked to unregister."""
    echoed_id = shield_id
    return echoed_id
class DatasetsImpl(Impl): class DatasetsImpl(Impl):
def __init__(self): def __init__(self):
@ -191,12 +196,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.register_shield(shield_id="test-shield", provider_id="test_provider") await table.register_shield(shield_id="test-shield", provider_id="test_provider")
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider") await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
shields = await table.list_shields() shields = await table.list_shields()
assert len(shields.data) == 2 assert len(shields.data) == 2
shield_ids = {s.identifier for s in shields.data} shield_ids = {s.identifier for s in shields.data}
assert "test-shield" in shield_ids assert "test-shield" in shield_ids
assert "test-shield-2" in shield_ids assert "test-shield-2" in shield_ids
# Test get specific shield
test_shield = await table.get_shield(identifier="test-shield")
assert test_shield is not None
assert test_shield.identifier == "test-shield"
assert test_shield.provider_id == "test_provider"
assert test_shield.provider_resource_id == "test-shield"
assert test_shield.params == {}
# Test get non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.get_shield(identifier="non-existent")
# Test unregistering shields
await table.unregister_shield(identifier="test-shield")
shields = await table.list_shields()
assert len(shields.data) == 1
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" not in shield_ids
assert "test-shield-2" in shield_ids
# Unregister the remaining shield
await table.unregister_shield(identifier="test-shield-2")
shields = await table.list_shields()
assert len(shields.data) == 0
# Test unregistering non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.unregister_shield(identifier="non-existent")
async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})