fix unit tests

This commit is contained in:
Ashwin Bharambe 2025-03-19 13:49:23 -07:00
parent b937a49436
commit 01fac67e33
6 changed files with 37 additions and 40 deletions

View file

@ -18,7 +18,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
@pytest.fixture(scope="function")
async def kvstore():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_registry_acl.db")
@ -29,13 +29,14 @@ async def kvstore():
shutil.rmtree(temp_dir)
@pytest.fixture
@pytest.fixture(scope="function")
async def registry(kvstore):
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
return registry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(registry):
model = ModelWithACL(
identifier="model-acl",
@ -68,7 +69,7 @@ async def test_registry_cache_with_acl(registry):
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
new_registry = CachedDiskDistributionRegistry(registry._kvstore)
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
await new_registry.initialize()
new_model = await new_registry.get("model", "model-acl")
@ -79,6 +80,7 @@ async def test_registry_cache_with_acl(registry):
assert new_model.access_attributes.teams is None
@pytest.mark.asyncio
async def test_registry_empty_acl(registry):
model = ModelWithACL(
identifier="model-empty-acl",
@ -118,24 +120,30 @@ async def test_registry_empty_acl(registry):
assert len(all_models) == 2
@pytest.mark.asyncio
async def test_registry_serialization(registry):
attributes = AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
)
model = ModelWithACL(
identifier="model-serialize",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
),
access_attributes=attributes,
)
await registry.register(model)
registry.cache.clear()
loaded_model = await registry.get("model", "model-serialize")
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
await new_registry.initialize()
loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]