mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
feat: create unregister shield API endpoint and CLI for shields management
This commit is contained in:
parent
537dc693ee
commit
93be3aa9b6
20 changed files with 563 additions and 1 deletions
34
docs/_static/llama-stack-spec.html
vendored
34
docs/_static/llama-stack-spec.html
vendored
|
@ -1452,6 +1452,40 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"delete": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK"
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"$ref": "#/components/responses/BadRequest400"
|
||||||
|
},
|
||||||
|
"429": {
|
||||||
|
"$ref": "#/components/responses/TooManyRequests429"
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"$ref": "#/components/responses/InternalServerError500"
|
||||||
|
},
|
||||||
|
"default": {
|
||||||
|
"$ref": "#/components/responses/DefaultError"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Shields"
|
||||||
|
],
|
||||||
|
"description": "Unregister a shield.",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "identifier",
|
||||||
|
"in": "path",
|
||||||
|
"description": "The identifier of the shield to unregister.",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {
|
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {
|
||||||
|
|
25
docs/_static/llama-stack-spec.yaml
vendored
25
docs/_static/llama-stack-spec.yaml
vendored
|
@ -999,6 +999,31 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
delete:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: OK
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Shields
|
||||||
|
description: Unregister a shield.
|
||||||
|
parameters:
|
||||||
|
- name: identifier
|
||||||
|
in: path
|
||||||
|
description: >-
|
||||||
|
The identifier of the shield to unregister.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
|
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
|
|
|
@ -79,3 +79,11 @@ class Shields(Protocol):
|
||||||
:returns: A Shield.
|
:returns: A Shield.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/shields/{identifier:path}", method="DELETE")
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
"""Unregister a shield.
|
||||||
|
|
||||||
|
:param identifier: The identifier of the shield to unregister.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
|
@ -8,6 +8,7 @@ import argparse
|
||||||
|
|
||||||
from .download import Download
|
from .download import Download
|
||||||
from .model import ModelParser
|
from .model import ModelParser
|
||||||
|
from .shield import ShieldParser
|
||||||
from .stack import StackParser
|
from .stack import StackParser
|
||||||
from .stack.utils import print_subcommand_description
|
from .stack.utils import print_subcommand_description
|
||||||
from .verify_download import VerifyDownload
|
from .verify_download import VerifyDownload
|
||||||
|
@ -31,6 +32,7 @@ class LlamaCLIParser:
|
||||||
|
|
||||||
# Add sub-commands
|
# Add sub-commands
|
||||||
ModelParser.create(subparsers)
|
ModelParser.create(subparsers)
|
||||||
|
ShieldParser.create(subparsers)
|
||||||
StackParser.create(subparsers)
|
StackParser.create(subparsers)
|
||||||
Download.create(subparsers)
|
Download.create(subparsers)
|
||||||
VerifyDownload.create(subparsers)
|
VerifyDownload.create(subparsers)
|
||||||
|
|
9
llama_stack/cli/shield/__init__.py
Normal file
9
llama_stack/cli/shield/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .shield import ShieldParser
|
||||||
|
|
||||||
|
__all__ = ["ShieldParser"]
|
93
llama_stack/cli/shield/describe.py
Normal file
93
llama_stack/cli/shield/describe.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
from llama_stack.cli.table import print_table
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldDescribe(Subcommand):
|
||||||
|
"""Show details about a shield"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"describe",
|
||||||
|
prog="llama shield describe",
|
||||||
|
description="Show details about a shield",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
self._add_arguments()
|
||||||
|
self.parser.set_defaults(func=self._run_shield_describe_cmd)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
"shield_id",
|
||||||
|
type=str,
|
||||||
|
help="The identifier of the shield to describe",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8321",
|
||||||
|
help="URL of the Llama Stack server (default: http://localhost:8321)",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--output-format",
|
||||||
|
choices=["table", "json"],
|
||||||
|
default="table",
|
||||||
|
help="Output format (default: table)",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_shield_describe_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
try:
|
||||||
|
response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}")
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
shield = response.json()
|
||||||
|
|
||||||
|
if args.output_format == "json":
|
||||||
|
print(json.dumps(shield, indent=2))
|
||||||
|
return
|
||||||
|
|
||||||
|
headers = ["Property", "Value"]
|
||||||
|
|
||||||
|
shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
|
||||||
|
provider_id = shield.get("provider_id", "<Not Set>")
|
||||||
|
provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))
|
||||||
|
resource_type = shield.get("type", "shield")
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
("Shield ID", shield_id),
|
||||||
|
("Provider ID", provider_id),
|
||||||
|
("Provider Shield ID", provider_shield_id),
|
||||||
|
("Resource Type", resource_type),
|
||||||
|
]
|
||||||
|
|
||||||
|
if shield.get("params"):
|
||||||
|
rows.append(("Parameters", json.dumps(shield["params"], indent=2)))
|
||||||
|
else:
|
||||||
|
rows.append(("Parameters", "<None>"))
|
||||||
|
|
||||||
|
print_table(
|
||||||
|
rows,
|
||||||
|
headers,
|
||||||
|
separate_rows=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 400 or e.response.status_code == 404:
|
||||||
|
print(f"Shield '{args.shield_id}' not found.")
|
||||||
|
else:
|
||||||
|
print(f"HTTP error {e.response.status_code}: {e.response.text}")
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
print(f"Error connecting to Llama Stack server: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error describing shield: {e}")
|
90
llama_stack/cli/shield/list.py
Normal file
90
llama_stack/cli/shield/list.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
from llama_stack.cli.table import print_table
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldList(Subcommand):
|
||||||
|
"""List available shields"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"list",
|
||||||
|
prog="llama shield list",
|
||||||
|
description="Show available shields",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
self._add_arguments()
|
||||||
|
self.parser.set_defaults(func=self._run_shield_list_cmd)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8321",
|
||||||
|
help="URL of the Llama Stack server (default: http://localhost:8321)",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--output-format",
|
||||||
|
choices=["table", "json"],
|
||||||
|
default="table",
|
||||||
|
help="Output format (default: table)",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_shield_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
try:
|
||||||
|
response = httpx.get(f"{args.url}/v1/shields")
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
shields = data.get("data", [])
|
||||||
|
|
||||||
|
if args.output_format == "json":
|
||||||
|
print(json.dumps(shields, indent=2))
|
||||||
|
return
|
||||||
|
|
||||||
|
if not shields:
|
||||||
|
print("No shields found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
headers = ["Shield ID", "Provider ID", "Provider Shield ID", "Parameters"]
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for shield in shields:
|
||||||
|
params_str = ""
|
||||||
|
if shield.get("params"):
|
||||||
|
params_str = json.dumps(shield["params"], separators=(",", ":"))
|
||||||
|
if len(params_str) > 50:
|
||||||
|
params_str = params_str[:47] + "..."
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
shield.get("identifier", shield.get("shield_id", "-")),
|
||||||
|
shield.get("provider_id", "-"),
|
||||||
|
shield.get("provider_resource_id", shield.get("provider_shield_id", "-")),
|
||||||
|
params_str or "-",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print_table(
|
||||||
|
rows,
|
||||||
|
headers,
|
||||||
|
separate_rows=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
print(f"Error connecting to Llama Stack server: {e}")
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
print(f"HTTP error {e.response.status_code}: {e.response.text}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error listing shields: {e}")
|
103
llama_stack/cli/shield/register.py
Normal file
103
llama_stack/cli/shield/register.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldRegister(Subcommand):
|
||||||
|
"""Register a new shield"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"register",
|
||||||
|
prog="llama shield register",
|
||||||
|
description="Register a new shield",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
self._add_arguments()
|
||||||
|
self.parser.set_defaults(func=self._run_shield_register_cmd)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
"shield_id",
|
||||||
|
type=str,
|
||||||
|
help="The identifier for the shield to register",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--provider-id",
|
||||||
|
type=str,
|
||||||
|
help="The provider ID for the shield",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--provider-shield-id",
|
||||||
|
type=str,
|
||||||
|
help="The provider-specific shield identifier",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--params",
|
||||||
|
type=str,
|
||||||
|
help='Shield parameters as JSON string (e.g., \'{"key": "value"}\')',
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8321",
|
||||||
|
help="URL of the Llama Stack server (default: http://localhost:8321)",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_shield_register_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
try:
|
||||||
|
params = None
|
||||||
|
if args.params:
|
||||||
|
try:
|
||||||
|
params = json.loads(args.params)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"Error parsing parameters JSON: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"shield_id": args.shield_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.provider_id:
|
||||||
|
payload["provider_id"] = args.provider_id
|
||||||
|
if args.provider_shield_id:
|
||||||
|
payload["provider_shield_id"] = args.provider_shield_id
|
||||||
|
if params:
|
||||||
|
payload["params"] = params
|
||||||
|
|
||||||
|
response = httpx.post(f"{args.url}/v1/shields", json=payload, headers={"Content-Type": "application/json"})
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
shield = response.json()
|
||||||
|
|
||||||
|
print(f" Shield '{shield.get('identifier', args.shield_id)}' registered successfully!")
|
||||||
|
print(f" Provider ID: {shield.get('provider_id', '<Not Set>')}")
|
||||||
|
print(
|
||||||
|
f" Provider Shield ID: {shield.get('provider_resource_id', shield.get('provider_shield_id', '<Not Set>'))}"
|
||||||
|
)
|
||||||
|
if shield.get("params"):
|
||||||
|
print(f" Parameters: {json.dumps(shield['params'], indent=2)}")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 400 and "already exists" in e.response.text:
|
||||||
|
print(f"Shield '{args.shield_id}' already exists.")
|
||||||
|
else:
|
||||||
|
print(f"HTTP error {e.response.status_code}: {e.response.text}")
|
||||||
|
sys.exit(1)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
print(f"Error connecting to Llama Stack server: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error registering shield: {e}")
|
||||||
|
sys.exit(1)
|
40
llama_stack/cli/shield/shield.py
Normal file
40
llama_stack/cli/shield/shield.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||||
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
from .describe import ShieldDescribe
|
||||||
|
from .list import ShieldList
|
||||||
|
from .register import ShieldRegister
|
||||||
|
from .unregister import ShieldUnregister
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldParser(Subcommand):
|
||||||
|
"""Parser for shield commands"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"shield",
|
||||||
|
prog="llama shield",
|
||||||
|
description="Manage safety shields",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||||
|
|
||||||
|
subparsers = self.parser.add_subparsers(title="shield_subcommands")
|
||||||
|
|
||||||
|
# Add shield sub-commands
|
||||||
|
ShieldList.create(subparsers)
|
||||||
|
ShieldRegister.create(subparsers)
|
||||||
|
ShieldDescribe.create(subparsers)
|
||||||
|
ShieldUnregister.create(subparsers)
|
||||||
|
|
||||||
|
print_subcommand_description(self.parser, subparsers)
|
91
llama_stack/cli/shield/unregister.py
Normal file
91
llama_stack/cli/shield/unregister.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldUnregister(Subcommand):
|
||||||
|
"""Unregister a shield"""
|
||||||
|
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"unregister",
|
||||||
|
prog="llama shield unregister",
|
||||||
|
description="Unregister a shield",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
self._add_arguments()
|
||||||
|
self.parser.set_defaults(func=self._run_shield_unregister_cmd)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
"shield_id",
|
||||||
|
type=str,
|
||||||
|
help="The identifier of the shield to unregister",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8321",
|
||||||
|
help="URL of the Llama Stack server (default: http://localhost:8321)",
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--force",
|
||||||
|
action="store_true",
|
||||||
|
help="Force unregister without confirmation",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_shield_unregister_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
response = httpx.get(f"{args.url}/v1/shields/{args.shield_id}")
|
||||||
|
response.raise_for_status()
|
||||||
|
shield = response.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 400 or e.response.status_code == 404:
|
||||||
|
print(f"Shield '{args.shield_id}' not found.")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not args.force:
|
||||||
|
shield_id = shield.get("identifier", shield.get("shield_id", args.shield_id))
|
||||||
|
provider_id = shield.get("provider_id", "<Not Set>")
|
||||||
|
provider_shield_id = shield.get("provider_resource_id", shield.get("provider_shield_id", "<Not Set>"))
|
||||||
|
|
||||||
|
print("Shield to unregister:")
|
||||||
|
print(f" - Shield ID: {shield_id}")
|
||||||
|
print(f" - Provider ID: {provider_id}")
|
||||||
|
print(f" - Provider Shield ID: {provider_shield_id}")
|
||||||
|
|
||||||
|
response_input = input(f"\nAre you sure you want to unregister shield '{args.shield_id}'? (y/N): ")
|
||||||
|
if response_input.lower() not in ["y", "yes"]:
|
||||||
|
print("Unregister cancelled.")
|
||||||
|
return
|
||||||
|
|
||||||
|
response = httpx.delete(f"{args.url}/v1/shields/{args.shield_id}")
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
print(f"Shield '{args.shield_id}' unregistered successfully!")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 400 or e.response.status_code == 404:
|
||||||
|
print(f"Shield '{args.shield_id}' not found.")
|
||||||
|
else:
|
||||||
|
print(f"HTTP error {e.response.status_code}: {e.response.text}")
|
||||||
|
sys.exit(1)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
print(f"Error connecting to Llama Stack server: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error unregistering shield: {e}")
|
||||||
|
sys.exit(1)
|
|
@ -43,6 +43,10 @@ class SafetyRouter(Safety):
|
||||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
||||||
|
return await self.routing_table.unregister_shield(identifier)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
return await p.unregister_vector_db(obj.identifier)
|
return await p.unregister_vector_db(obj.identifier)
|
||||||
elif api == Api.inference:
|
elif api == Api.inference:
|
||||||
return await p.unregister_model(obj.identifier)
|
return await p.unregister_model(obj.identifier)
|
||||||
|
elif api == Api.safety:
|
||||||
|
return await p.unregister_shield(obj.identifier)
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
return await p.unregister_dataset(obj.identifier)
|
return await p.unregister_dataset(obj.identifier)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
|
|
|
@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
)
|
)
|
||||||
await self.register_object(shield)
|
await self.register_object(shield)
|
||||||
return shield
|
return shield
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
existing_shield = await self.get_shield(identifier)
|
||||||
|
if existing_shield is None:
|
||||||
|
raise ValueError(f"Shield '{identifier}' not found")
|
||||||
|
logger.info(f"Shield {identifier} was unregistered successfully.")
|
||||||
|
await self.unregister_object(existing_shield)
|
||||||
|
|
|
@ -51,6 +51,8 @@ class ModelsProtocolPrivate(Protocol):
|
||||||
class ShieldsProtocolPrivate(Protocol):
|
class ShieldsProtocolPrivate(Protocol):
|
||||||
async def register_shield(self, shield: Shield) -> None: ...
|
async def register_shield(self, shield: Shield) -> None: ...
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class VectorDBsProtocolPrivate(Protocol):
|
class VectorDBsProtocolPrivate(Protocol):
|
||||||
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
|
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
|
||||||
|
|
|
@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
# The model will be validated during runtime when making inference calls
|
# The model will be validated during runtime when making inference calls
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
# LlamaGuard doesn't need to do anything special for unregistration
|
||||||
|
# The routing table handles the removal from the registry
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
||||||
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
|
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
|
|
|
@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
|
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
|
|
@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
if not shield.provider_resource_id:
|
if not shield.provider_resource_id:
|
||||||
raise ValueError("Shield model not provided.")
|
raise ValueError("Shield model not provided.")
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
|
|
@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
||||||
):
|
):
|
||||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||||
|
|
||||||
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
|
|
@ -8,6 +8,8 @@
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
|
@ -53,6 +55,9 @@ class SafetyImpl(Impl):
|
||||||
async def register_shield(self, shield: Shield):
|
async def register_shield(self, shield: Shield):
|
||||||
return shield
|
return shield
|
||||||
|
|
||||||
|
async def unregister_shield(self, shield_id: str):
|
||||||
|
return shield_id
|
||||||
|
|
||||||
|
|
||||||
class DatasetsImpl(Impl):
|
class DatasetsImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -166,12 +171,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
|
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
|
||||||
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
|
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
|
||||||
shields = await table.list_shields()
|
shields = await table.list_shields()
|
||||||
|
|
||||||
assert len(shields.data) == 2
|
assert len(shields.data) == 2
|
||||||
|
|
||||||
shield_ids = {s.identifier for s in shields.data}
|
shield_ids = {s.identifier for s in shields.data}
|
||||||
assert "test-shield" in shield_ids
|
assert "test-shield" in shield_ids
|
||||||
assert "test-shield-2" in shield_ids
|
assert "test-shield-2" in shield_ids
|
||||||
|
|
||||||
|
# Test get specific shield
|
||||||
|
test_shield = await table.get_shield(identifier="test-shield")
|
||||||
|
assert test_shield is not None
|
||||||
|
assert test_shield.identifier == "test-shield"
|
||||||
|
assert test_shield.provider_id == "test_provider"
|
||||||
|
assert test_shield.provider_resource_id == "test-shield"
|
||||||
|
assert test_shield.params == {}
|
||||||
|
|
||||||
|
# Test get non-existent shield - should raise ValueError with specific message
|
||||||
|
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||||
|
await table.get_shield(identifier="non-existent")
|
||||||
|
|
||||||
|
# Test unregistering shields
|
||||||
|
await table.unregister_shield(identifier="test-shield")
|
||||||
|
shields = await table.list_shields()
|
||||||
|
|
||||||
|
assert len(shields.data) == 1
|
||||||
|
shield_ids = {s.identifier for s in shields.data}
|
||||||
|
assert "test-shield" not in shield_ids
|
||||||
|
assert "test-shield-2" in shield_ids
|
||||||
|
|
||||||
|
# Unregister the remaining shield
|
||||||
|
await table.unregister_shield(identifier="test-shield-2")
|
||||||
|
shields = await table.list_shields()
|
||||||
|
assert len(shields.data) == 0
|
||||||
|
|
||||||
|
# Test unregistering non-existent shield - should raise ValueError with specific message
|
||||||
|
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||||
|
await table.unregister_shield(identifier="non-existent")
|
||||||
|
|
||||||
|
|
||||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue