chore: Add fixtures to conftest.py (#2067)

Add fixtures for SqliteKVStore, DiskDistributionRegistry and
CachedDiskDistributionRegistry. And use them in tests that had all been
duplicating similar setups.

## Test Plan
unit tests continue to run

Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
Derek Higgins 2025-05-06 12:57:48 +01:00 committed by GitHub
parent 4597145011
commit 2e807b38cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 122 additions and 167 deletions

5
tests/unit/__init__.py Normal file
View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

9
tests/unit/conftest.py Normal file
View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# We need to import the fixtures here so that pytest can find them
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401

View file

@ -26,20 +26,6 @@ from llama_stack.distribution.routers.routing_tables import (
ToolGroupsRoutingTable, ToolGroupsRoutingTable,
VectorDBsRoutingTable, VectorDBsRoutingTable,
) )
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
async def dist_registry(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
yield registry
class Impl: class Impl:
@ -136,8 +122,8 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_models_routing_table(dist_registry): async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, dist_registry) table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
await table.initialize() await table.initialize()
# Register multiple models and verify listing # Register multiple models and verify listing
@ -178,8 +164,8 @@ async def test_models_routing_table(dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_shields_routing_table(dist_registry): async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, dist_registry) table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
await table.initialize() await table.initialize()
# Register multiple shields and verify listing # Register multiple shields and verify listing
@ -194,11 +180,11 @@ async def test_shields_routing_table(dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_vectordbs_routing_table(dist_registry): async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, dist_registry) table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
await table.initialize() await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, dist_registry) m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
await m_table.initialize() await m_table.initialize()
await m_table.register_model( await m_table.register_model(
model_id="test-model", model_id="test-model",
@ -224,8 +210,8 @@ async def test_vectordbs_routing_table(dist_registry):
assert len(vector_dbs.data) == 0 assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(dist_registry): async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, dist_registry) table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
await table.initialize() await table.initialize()
# Register multiple datasets and verify listing # Register multiple datasets and verify listing
@ -250,8 +236,8 @@ async def test_datasets_routing_table(dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scoring_functions_routing_table(dist_registry): async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, dist_registry) table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
await table.initialize() await table.initialize()
# Register multiple scoring functions and verify listing # Register multiple scoring functions and verify listing
@ -276,8 +262,8 @@ async def test_scoring_functions_routing_table(dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_benchmarks_routing_table(dist_registry): async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, dist_registry) table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
await table.initialize() await table.initialize()
# Register multiple benchmarks and verify listing # Register multiple benchmarks and verify listing
@ -294,8 +280,8 @@ async def test_benchmarks_routing_table(dist_registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_groups_routing_table(dist_registry): async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, dist_registry) table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
await table.initialize() await table.initialize()
# Register multiple tool groups and verify listing # Register multiple tool groups and verify listing

34
tests/unit/fixtures.py Normal file
View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")
async def sqlite_kvstore(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
yield kvstore
@pytest.fixture(scope="function")
async def disk_dist_registry(sqlite_kvstore):
registry = DiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()
yield registry
@pytest.fixture(scope="function")
async def cached_disk_dist_registry(sqlite_kvstore):
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()
yield registry

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
import shutil
import tempfile
import uuid import uuid
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
@ -17,20 +14,12 @@ from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture @pytest.fixture
async def test_setup(): async def test_setup(sqlite_kvstore):
temp_dir = tempfile.mkdtemp() agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
db_path = os.path.join(temp_dir, "test_persistence_access_control.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore)
yield agent_persistence yield agent_persistence
shutil.rmtree(temp_dir)
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -4,10 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.inference import Model from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
@ -20,28 +18,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
def config():
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
if os.path.exists(config.db_path):
os.remove(config.db_path)
return config
@pytest_asyncio.fixture(scope="function")
async def registry(config):
registry = DiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest_asyncio.fixture(scope="function")
async def cached_registry(config):
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize()
return registry
@pytest.fixture @pytest.fixture
def sample_vector_db(): def sample_vector_db():
return VectorDB( return VectorDB(
@ -63,41 +39,42 @@ def sample_model():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_initialization(registry): async def test_registry_initialization(disk_dist_registry):
# Test empty registry # Test empty registry
result = await registry.get("nonexistent", "nonexistent") result = await disk_dist_registry.get("nonexistent", "nonexistent")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_registration(registry, sample_vector_db, sample_model): async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}") print(f"Registering {sample_vector_db}")
await registry.register(sample_vector_db) await disk_dist_registry.register(sample_vector_db)
print(f"Registering {sample_model}") print(f"Registering {sample_model}")
await registry.register(sample_model) await disk_dist_registry.register(sample_model)
print("Getting vector_db") print("Getting vector_db")
result_vector_db = await registry.get("vector_db", "test_vector_db") result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db")
assert result_vector_db is not None assert result_vector_db is not None
assert result_vector_db.identifier == sample_vector_db.identifier assert result_vector_db.identifier == sample_vector_db.identifier
assert result_vector_db.embedding_model == sample_vector_db.embedding_model assert result_vector_db.embedding_model == sample_vector_db.embedding_model
assert result_vector_db.provider_id == sample_vector_db.provider_id assert result_vector_db.provider_id == sample_vector_db.provider_id
result_model = await registry.get("model", "test_model") result_model = await disk_dist_registry.get("model", "test_model")
assert result_model is not None assert result_model is not None
assert result_model.identifier == sample_model.identifier assert result_model.identifier == sample_model.identifier
assert result_model.provider_id == sample_model.provider_id assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cached_registry_initialization(config, sample_vector_db, sample_model): async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
# First populate the disk registry # First populate the disk registry
disk_registry = DiskDistributionRegistry(await kvstore_impl(config)) disk_registry = DiskDistributionRegistry(sqlite_kvstore)
await disk_registry.initialize() await disk_registry.initialize()
await disk_registry.register(sample_vector_db) await disk_registry.register(sample_vector_db)
await disk_registry.register(sample_model) await disk_registry.register(sample_model)
# Test cached version loads from disk # Test cached version loads from disk
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) db_path = sqlite_kvstore.db_path
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
await cached_registry.initialize() await cached_registry.initialize()
result_vector_db = await cached_registry.get("vector_db", "test_vector_db") result_vector_db = await cached_registry.get("vector_db", "test_vector_db")
@ -109,10 +86,7 @@ async def test_cached_registry_initialization(config, sample_vector_db, sample_m
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cached_registry_updates(config): async def test_cached_registry_updates(cached_disk_dist_registry):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
new_vector_db = VectorDB( new_vector_db = VectorDB(
identifier="test_vector_db_2", identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
@ -120,16 +94,17 @@ async def test_cached_registry_updates(config):
provider_resource_id="test_vector_db_2", provider_resource_id="test_vector_db_2",
provider_id="baz", provider_id="baz",
) )
await cached_registry.register(new_vector_db) await cached_disk_dist_registry.register(new_vector_db)
# Verify in cache # Verify in cache
result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2") result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None assert result_vector_db is not None
assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.identifier == new_vector_db.identifier
assert result_vector_db.provider_id == new_vector_db.provider_id assert result_vector_db.provider_id == new_vector_db.provider_id
# Verify persisted to disk # Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config)) db_path = cached_disk_dist_registry.kvstore.db_path
new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)))
await new_registry.initialize() await new_registry.initialize()
result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") result_vector_db = await new_registry.get("vector_db", "test_vector_db_2")
assert result_vector_db is not None assert result_vector_db is not None
@ -138,10 +113,7 @@ async def test_cached_registry_updates(config):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_duplicate_provider_registration(config): async def test_duplicate_provider_registration(cached_disk_dist_registry):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize()
original_vector_db = VectorDB( original_vector_db = VectorDB(
identifier="test_vector_db_2", identifier="test_vector_db_2",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
@ -149,7 +121,7 @@ async def test_duplicate_provider_registration(config):
provider_resource_id="test_vector_db_2", provider_resource_id="test_vector_db_2",
provider_id="baz", provider_id="baz",
) )
await cached_registry.register(original_vector_db) await cached_disk_dist_registry.register(original_vector_db)
duplicate_vector_db = VectorDB( duplicate_vector_db = VectorDB(
identifier="test_vector_db_2", identifier="test_vector_db_2",
@ -158,18 +130,16 @@ async def test_duplicate_provider_registration(config):
provider_resource_id="test_vector_db_2", provider_resource_id="test_vector_db_2",
provider_id="baz", # Same provider_id provider_id="baz", # Same provider_id
) )
await cached_registry.register(duplicate_vector_db) await cached_disk_dist_registry.register(duplicate_vector_db)
result = await cached_registry.get("vector_db", "test_vector_db_2") result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_all_objects(config): async def test_get_all_objects(cached_disk_dist_registry):
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) # Create multiple test banks
await cached_registry.initialize()
# Create multiple test banks # Create multiple test banks
test_vector_dbs = [ test_vector_dbs = [
VectorDB( VectorDB(
@ -184,10 +154,10 @@ async def test_get_all_objects(config):
# Register all vector_dbs # Register all vector_dbs
for vector_db in test_vector_dbs: for vector_db in test_vector_dbs:
await cached_registry.register(vector_db) await cached_disk_dist_registry.register(vector_db)
# Test get_all retrieval # Test get_all retrieval
all_results = await cached_registry.get_all() all_results = await cached_disk_dist_registry.get_all()
assert len(all_results) == 3 assert len(all_results) == 3
# Verify each vector_db was stored correctly # Verify each vector_db was stored correctly
@ -201,9 +171,7 @@ async def test_get_all_objects(config):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_registry_values_error_handling(config): async def test_parse_registry_values_error_handling(sqlite_kvstore):
kvstore = await kvstore_impl(config)
valid_db = VectorDB( valid_db = VectorDB(
identifier="valid_vector_db", identifier="valid_vector_db",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
@ -212,16 +180,18 @@ async def test_parse_registry_values_error_handling(config):
provider_id="test-provider", provider_id="test-provider",
) )
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()) await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}', '{"type": "vector_db", "identifier": "missing_fields"}',
) )
test_registry = DiskDistributionRegistry(kvstore) test_registry = DiskDistributionRegistry(sqlite_kvstore)
await test_registry.initialize() await test_registry.initialize()
# Get all objects, which should only return the valid one # Get all objects, which should only return the valid one
@ -240,9 +210,7 @@ async def test_parse_registry_values_error_handling(config):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cached_registry_error_handling(config): async def test_cached_registry_error_handling(sqlite_kvstore):
kvstore = await kvstore_impl(config)
valid_db = VectorDB( valid_db = VectorDB(
identifier="valid_cached_db", identifier="valid_cached_db",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
@ -251,14 +219,16 @@ async def test_cached_registry_error_handling(config):
provider_id="test-provider", provider_id="test-provider",
) )
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()) await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
)
await kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"), KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string '{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
) )
cached_registry = CachedDiskDistributionRegistry(kvstore) cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await cached_registry.initialize() await cached_registry.initialize()
all_objects = await cached_registry.get_all() all_objects = await cached_registry.get_all()

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
import shutil
import tempfile
import pytest import pytest
@ -14,30 +11,10 @@ from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.server.auth_providers import AccessAttributes from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")
async def kvstore():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_registry_acl.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
yield kvstore
shutil.rmtree(temp_dir)
@pytest.fixture(scope="function")
async def registry(kvstore):
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
return registry
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_cache_with_acl(registry): async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithACL( model = ModelWithACL(
identifier="model-acl", identifier="model-acl",
provider_id="test-provider", provider_id="test-provider",
@ -46,30 +23,30 @@ async def test_registry_cache_with_acl(registry):
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]), access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
) )
success = await registry.register(model) success = await cached_disk_dist_registry.register(model)
assert success assert success
cached_model = registry.get_cached("model", "model-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.identifier == "model-acl" assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"] assert cached_model.access_attributes.roles == ["admin"]
assert cached_model.access_attributes.teams == ["ai-team"] assert cached_model.access_attributes.teams == ["ai-team"]
fetched_model = await registry.get("model", "model-acl") fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
assert fetched_model is not None assert fetched_model is not None
assert fetched_model.identifier == "model-acl" assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"] assert fetched_model.access_attributes.roles == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"]) model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await registry.update(model) await cached_disk_dist_registry.update(model)
updated_cached = registry.get_cached("model", "model-acl") updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
assert updated_cached is not None assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"] assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"] assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None assert updated_cached.access_attributes.teams is None
new_registry = CachedDiskDistributionRegistry(registry.kvstore) new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize() await new_registry.initialize()
new_model = await new_registry.get("model", "model-acl") new_model = await new_registry.get("model", "model-acl")
@ -81,7 +58,7 @@ async def test_registry_cache_with_acl(registry):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_empty_acl(registry): async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithACL( model = ModelWithACL(
identifier="model-empty-acl", identifier="model-empty-acl",
provider_id="test-provider", provider_id="test-provider",
@ -90,9 +67,9 @@ async def test_registry_empty_acl(registry):
access_attributes=AccessAttributes(), access_attributes=AccessAttributes(),
) )
await registry.register(model) await cached_disk_dist_registry.register(model)
cached_model = registry.get_cached("model", "model-empty-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.access_attributes is not None assert cached_model.access_attributes is not None
assert cached_model.access_attributes.roles is None assert cached_model.access_attributes.roles is None
@ -100,7 +77,7 @@ async def test_registry_empty_acl(registry):
assert cached_model.access_attributes.projects is None assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None assert cached_model.access_attributes.namespaces is None
all_models = await registry.get_all() all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 1 assert len(all_models) == 1
model = ModelWithACL( model = ModelWithACL(
@ -110,18 +87,18 @@ async def test_registry_empty_acl(registry):
model_type=ModelType.llm, model_type=ModelType.llm,
) )
await registry.register(model) await cached_disk_dist_registry.register(model)
cached_model = registry.get_cached("model", "model-no-acl") cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
assert cached_model is not None assert cached_model is not None
assert cached_model.access_attributes is None assert cached_model.access_attributes is None
all_models = await registry.get_all() all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 2 assert len(all_models) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_serialization(registry): async def test_registry_serialization(cached_disk_dist_registry):
attributes = AccessAttributes( attributes = AccessAttributes(
roles=["admin", "researcher"], roles=["admin", "researcher"],
teams=["ai-team", "ml-team"], teams=["ai-team", "ml-team"],
@ -137,9 +114,9 @@ async def test_registry_serialization(registry):
access_attributes=attributes, access_attributes=attributes,
) )
await registry.register(model) await cached_disk_dist_registry.register(model)
new_registry = CachedDiskDistributionRegistry(registry.kvstore) new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize() await new_registry.initialize()
loaded_model = await new_registry.get("model", "model-serialize") loaded_model = await new_registry.get("model", "model-serialize")

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
import shutil
import tempfile
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
@ -15,9 +12,6 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
class AsyncMock(MagicMock): class AsyncMock(MagicMock):
@ -30,25 +24,16 @@ def _return_model(model):
@pytest.fixture @pytest.fixture
async def test_setup(): async def test_setup(cached_disk_dist_registry):
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_access_control.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
mock_inference = Mock() mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock() mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model) mock_inference.register_model = AsyncMock(side_effect=_return_model)
routing_table = ModelsRoutingTable( routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference}, impls_by_provider_id={"test_provider": mock_inference},
dist_registry=registry, dist_registry=cached_disk_dist_registry,
) )
yield registry, routing_table yield cached_disk_dist_registry, routing_table
shutil.rmtree(temp_dir)
@pytest.mark.asyncio @pytest.mark.asyncio