feat: create unregister shield API endpoint

This commit is contained in:
r3v5 2025-07-22 12:28:18 +01:00
parent 9e6860b9cf
commit ac6034ed1e
No known key found for this signature in database
GPG key ID: 7758B9F272DE67D9
13 changed files with 135 additions and 1 deletions

View file

@ -8,6 +8,8 @@
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@ -53,6 +55,9 @@ class SafetyImpl(Impl):
async def register_shield(self, shield: Shield):
return shield
async def unregister_shield(self, shield_id: str):
return shield_id
class VectorDBImpl(Impl):
def __init__(self):
@ -166,12 +171,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
shields = await table.list_shields()
assert len(shields.data) == 2
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" in shield_ids
assert "test-shield-2" in shield_ids
# Test get specific shield
test_shield = await table.get_shield(identifier="test-shield")
assert test_shield is not None
assert test_shield.identifier == "test-shield"
assert test_shield.provider_id == "test_provider"
assert test_shield.provider_resource_id == "test-shield"
assert test_shield.params == {}
# Test get non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.get_shield(identifier="non-existent")
# Test unregistering shields
await table.unregister_shield(identifier="test-shield")
shields = await table.list_shields()
assert len(shields.data) == 1
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" not in shield_ids
assert "test-shield-2" in shield_ids
# Unregister the remaining shield
await table.unregister_shield(identifier="test-shield-2")
shields = await table.list_shields()
assert len(shields.data) == 0
# Test unregistering non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.unregister_shield(identifier="non-existent")
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})