mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-25 21:57:45 +00:00
Merge ac6034ed1e
into cbe89d2bdd
This commit is contained in:
commit
36a5379029
13 changed files with 135 additions and 1 deletions
34
docs/_static/llama-stack-spec.html
vendored
34
docs/_static/llama-stack-spec.html
vendored
|
@ -1452,6 +1452,40 @@
|
|||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Shields"
|
||||
],
|
||||
"description": "Unregister a shield.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "identifier",
|
||||
"in": "path",
|
||||
"description": "The identifier of the shield to unregister.",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/v1/telemetry/traces/{trace_id}/spans/{span_id}": {
|
||||
|
|
25
docs/_static/llama-stack-spec.yaml
vendored
25
docs/_static/llama-stack-spec.yaml
vendored
|
@ -999,6 +999,31 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
delete:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Shields
|
||||
description: Unregister a shield.
|
||||
parameters:
|
||||
- name: identifier
|
||||
in: path
|
||||
description: >-
|
||||
The identifier of the shield to unregister.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/telemetry/traces/{trace_id}/spans/{span_id}:
|
||||
get:
|
||||
responses:
|
||||
|
|
|
@ -79,3 +79,11 @@ class Shields(Protocol):
|
|||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE")
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
"""Unregister a shield.
|
||||
|
||||
:param identifier: The identifier of the shield to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -43,6 +43,10 @@ class SafetyRouter(Safety):
|
|||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
|
||||
return await self.routing_table.unregister_shield(identifier)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
|
|
@ -59,6 +59,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
return await p.unregister_vector_db(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.safety:
|
||||
return await p.unregister_shield(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
elif api == Api.tool_runtime:
|
||||
|
|
|
@ -55,3 +55,10 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
await self.register_object(shield)
|
||||
return shield
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
existing_shield = await self.get_shield(identifier)
|
||||
if existing_shield is None:
|
||||
raise ValueError(f"Shield '{identifier}' not found")
|
||||
logger.info(f"Shield {identifier} was unregistered successfully.")
|
||||
await self.unregister_object(existing_shield)
|
||||
|
|
|
@ -62,6 +62,8 @@ class ModelsProtocolPrivate(Protocol):
|
|||
class ShieldsProtocolPrivate(Protocol):
|
||||
async def register_shield(self, shield: Shield) -> None: ...
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None: ...
|
||||
|
||||
|
||||
class VectorDBsProtocolPrivate(Protocol):
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None: ...
|
||||
|
|
|
@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
# LlamaGuard doesn't need to do anything special for unregistration
|
||||
# The routing table handles the removal from the registry
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
|
|
@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
|
||||
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
|
|
@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
|
||||
)
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
|
|
|
@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
if not shield.provider_resource_id:
|
||||
raise ValueError("Shield model not provided.")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||
) -> RunShieldResponse:
|
||||
|
|
|
@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
|||
):
|
||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||
) -> RunShieldResponse:
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.datatypes import Api
|
||||
|
@ -78,6 +80,9 @@ class SafetyImpl(Impl):
|
|||
async def register_shield(self, shield: Shield):
|
||||
return shield
|
||||
|
||||
async def unregister_shield(self, shield_id: str):
|
||||
return shield_id
|
||||
|
||||
|
||||
class DatasetsImpl(Impl):
|
||||
def __init__(self):
|
||||
|
@ -191,12 +196,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
|
||||
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
|
||||
shields = await table.list_shields()
|
||||
|
||||
assert len(shields.data) == 2
|
||||
|
||||
shield_ids = {s.identifier for s in shields.data}
|
||||
assert "test-shield" in shield_ids
|
||||
assert "test-shield-2" in shield_ids
|
||||
|
||||
# Test get specific shield
|
||||
test_shield = await table.get_shield(identifier="test-shield")
|
||||
assert test_shield is not None
|
||||
assert test_shield.identifier == "test-shield"
|
||||
assert test_shield.provider_id == "test_provider"
|
||||
assert test_shield.provider_resource_id == "test-shield"
|
||||
assert test_shield.params == {}
|
||||
|
||||
# Test get non-existent shield - should raise ValueError with specific message
|
||||
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||
await table.get_shield(identifier="non-existent")
|
||||
|
||||
# Test unregistering shields
|
||||
await table.unregister_shield(identifier="test-shield")
|
||||
shields = await table.list_shields()
|
||||
|
||||
assert len(shields.data) == 1
|
||||
shield_ids = {s.identifier for s in shields.data}
|
||||
assert "test-shield" not in shield_ids
|
||||
assert "test-shield-2" in shield_ids
|
||||
|
||||
# Unregister the remaining shield
|
||||
await table.unregister_shield(identifier="test-shield-2")
|
||||
shields = await table.list_shields()
|
||||
assert len(shields.data) == 0
|
||||
|
||||
# Test unregistering non-existent shield - should raise ValueError with specific message
|
||||
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||
await table.unregister_shield(identifier="non-existent")
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue