mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
Merge ac6034ed1e
into cbe89d2bdd
This commit is contained in:
commit
36a5379029
13 changed files with 135 additions and 1 deletions
|
@ -8,6 +8,8 @@
|
|||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.datatypes import Api
|
||||
|
@ -78,6 +80,9 @@ class SafetyImpl(Impl):
|
|||
async def register_shield(self, shield: Shield):
|
||||
return shield
|
||||
|
||||
async def unregister_shield(self, shield_id: str):
|
||||
return shield_id
|
||||
|
||||
|
||||
class DatasetsImpl(Impl):
|
||||
def __init__(self):
|
||||
|
@ -191,12 +196,42 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
|||
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
|
||||
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
|
||||
shields = await table.list_shields()
|
||||
|
||||
assert len(shields.data) == 2
|
||||
|
||||
shield_ids = {s.identifier for s in shields.data}
|
||||
assert "test-shield" in shield_ids
|
||||
assert "test-shield-2" in shield_ids
|
||||
|
||||
# Test get specific shield
|
||||
test_shield = await table.get_shield(identifier="test-shield")
|
||||
assert test_shield is not None
|
||||
assert test_shield.identifier == "test-shield"
|
||||
assert test_shield.provider_id == "test_provider"
|
||||
assert test_shield.provider_resource_id == "test-shield"
|
||||
assert test_shield.params == {}
|
||||
|
||||
# Test get non-existent shield - should raise ValueError with specific message
|
||||
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||
await table.get_shield(identifier="non-existent")
|
||||
|
||||
# Test unregistering shields
|
||||
await table.unregister_shield(identifier="test-shield")
|
||||
shields = await table.list_shields()
|
||||
|
||||
assert len(shields.data) == 1
|
||||
shield_ids = {s.identifier for s in shields.data}
|
||||
assert "test-shield" not in shield_ids
|
||||
assert "test-shield-2" in shield_ids
|
||||
|
||||
# Unregister the remaining shield
|
||||
await table.unregister_shield(identifier="test-shield-2")
|
||||
shields = await table.list_shields()
|
||||
assert len(shields.data) == 0
|
||||
|
||||
# Test unregistering non-existent shield - should raise ValueError with specific message
|
||||
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
|
||||
await table.unregister_shield(identifier="non-existent")
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue