feat: create unregister shield API endpoint in Llama Stack

This commit is contained in:
r3v5 2025-07-22 12:28:18 +01:00
parent 140ee7d337
commit 22e73494b0
No known key found for this signature in database
GPG key ID: 7758B9F272DE67D9
13 changed files with 132 additions and 1 deletions

View file

@ -83,3 +83,11 @@ class Shields(Protocol):
:returns: A Shield.
"""
...
@webmethod(route="/shields/{identifier:path}", method="DELETE")
async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield.
:param identifier: The identifier of the shield to unregister.
"""
...

View file

@ -43,6 +43,10 @@ class SafetyRouter(Safety):
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def unregister_shield(self, identifier: str) -> None:
logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
return await self.routing_table.unregister_shield(identifier)
async def run_shield(
self,
shield_id: str,

View file

@ -60,6 +60,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.safety:
return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:

View file

@ -55,3 +55,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
)
await self.register_object(shield)
return shield
async def unregister_shield(self, identifier: str) -> None:
existing_shield = await self.get_shield(identifier)
await self.unregister_object(existing_shield)

View file

@ -65,6 +65,8 @@ class ModelsProtocolPrivate(Protocol):
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
async def unregister_shield(self, identifier: str) -> None: ...
class VectorDBsProtocolPrivate(Protocol):
async def register_vector_db(self, vector_db: VectorDB) -> None: ...

View file

@ -150,6 +150,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if not model_id:
raise ValueError("Llama Guard shield must have a model id")
async def unregister_shield(self, identifier: str) -> None:
# LlamaGuard doesn't need to do anything special for unregistration
# The routing table handles the removal from the registry
pass
async def run_shield(
self,
shield_id: str,

View file

@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if shield.provider_resource_id != PROMPT_GUARD_MODEL:
raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ")
async def unregister_shield(self, identifier: str) -> None:
pass
async def run_shield(
self,
shield_id: str,

View file

@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
)
async def unregister_shield(self, identifier: str) -> None:
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] = None
) -> RunShieldResponse:

View file

@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def unregister_shield(self, identifier: str) -> None:
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:

View file

@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
):
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
async def unregister_shield(self, identifier: str) -> None:
pass
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse: