From 2e807b38cc449ecb61a23281986081b3130f6033 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Tue, 6 May 2025 12:57:48 +0100 Subject: [PATCH] chore: Add fixtures to conftest.py (#2067) Add fixtures for SqliteKVStore, DiskDistributionRegistry and CachedDiskDistributionRegistry. And use them in tests that had all been duplicating similar setups. ## Test Plan unit tests continue to run Signed-off-by: Derek Higgins --- tests/unit/__init__.py | 5 + tests/unit/conftest.py | 9 ++ .../routers/test_routing_tables.py | 44 +++----- tests/unit/fixtures.py | 34 ++++++ .../agents/test_persistence_access_control.py | 15 +-- tests/unit/registry/test_registry.py | 104 +++++++----------- tests/unit/registry/test_registry_acl.py | 57 +++------- tests/unit/server/test_access_control.py | 21 +--- 8 files changed, 122 insertions(+), 167 deletions(-) create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/fixtures.py diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..aedac0386 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# We need to import the fixtures here so that pytest can find them +# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed +from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401 diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 305e53839..4e6585ad6 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -26,20 +26,6 @@ from llama_stack.distribution.routers.routing_tables import ( ToolGroupsRoutingTable, VectorDBsRoutingTable, ) -from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl - - -@pytest.fixture -async def dist_registry(tmp_path): - db_path = tmp_path / "test_kv.db" - kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix()) - kvstore = SqliteKVStoreImpl(kvstore_config) - await kvstore.initialize() - registry = CachedDiskDistributionRegistry(kvstore) - await registry.initialize() - yield registry class Impl: @@ -136,8 +122,8 @@ class ToolGroupsImpl(Impl): @pytest.mark.asyncio -async def test_models_routing_table(dist_registry): - table = ModelsRoutingTable({"test_provider": InferenceImpl()}, dist_registry) +async def test_models_routing_table(cached_disk_dist_registry): + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry) await table.initialize() # Register multiple models and verify listing @@ -178,8 +164,8 @@ async def test_models_routing_table(dist_registry): @pytest.mark.asyncio -async def test_shields_routing_table(dist_registry): - table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, dist_registry) +async def test_shields_routing_table(cached_disk_dist_registry): + table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry) await table.initialize() # Register multiple shields and verify listing @@ -194,11 +180,11 @@ async def test_shields_routing_table(dist_registry): @pytest.mark.asyncio -async def test_vectordbs_routing_table(dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, dist_registry) +async def test_vectordbs_routing_table(cached_disk_dist_registry): + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry) await table.initialize() - m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, dist_registry) + m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry) await m_table.initialize() await m_table.register_model( model_id="test-model", @@ -224,8 +210,8 @@ async def test_vectordbs_routing_table(dist_registry): assert len(vector_dbs.data) == 0 -async def test_datasets_routing_table(dist_registry): - table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, dist_registry) +async def test_datasets_routing_table(cached_disk_dist_registry): + table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry) await table.initialize() # Register multiple datasets and verify listing @@ -250,8 +236,8 @@ async def test_datasets_routing_table(dist_registry): @pytest.mark.asyncio -async def test_scoring_functions_routing_table(dist_registry): - table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, dist_registry) +async def test_scoring_functions_routing_table(cached_disk_dist_registry): + table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry) await table.initialize() # Register multiple scoring functions and verify listing @@ -276,8 +262,8 @@ async def test_scoring_functions_routing_table(dist_registry): @pytest.mark.asyncio -async def test_benchmarks_routing_table(dist_registry): - table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, dist_registry) +async def test_benchmarks_routing_table(cached_disk_dist_registry): + table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry) await table.initialize() # Register multiple benchmarks and verify listing @@ -294,8 +280,8 @@ async def test_benchmarks_routing_table(dist_registry): @pytest.mark.asyncio -async def test_tool_groups_routing_table(dist_registry): - table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, dist_registry) +async def test_tool_groups_routing_table(cached_disk_dist_registry): + table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry) await table.initialize() # Register multiple tool groups and verify listing diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py new file mode 100644 index 000000000..7174d2e78 --- /dev/null +++ b/tests/unit/fixtures.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl + + +@pytest.fixture(scope="function") +async def sqlite_kvstore(tmp_path): + db_path = tmp_path / "test_kv.db" + kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix()) + kvstore = SqliteKVStoreImpl(kvstore_config) + await kvstore.initialize() + yield kvstore + + +@pytest.fixture(scope="function") +async def disk_dist_registry(sqlite_kvstore): + registry = DiskDistributionRegistry(sqlite_kvstore) + await registry.initialize() + yield registry + + +@pytest.fixture(scope="function") +async def cached_disk_dist_registry(sqlite_kvstore): + registry = CachedDiskDistributionRegistry(sqlite_kvstore) + await registry.initialize() + yield registry diff --git a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py index ab181a4ae..06e1a778a 100644 --- a/tests/unit/providers/agents/test_persistence_access_control.py +++ b/tests/unit/providers/agents/test_persistence_access_control.py @@ -4,9 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os -import shutil -import tempfile import uuid from datetime import datetime from unittest.mock import patch @@ -17,20 +14,12 @@ from llama_stack.apis.agents import Turn from llama_stack.apis.inference import CompletionMessage, StopReason from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl @pytest.fixture -async def test_setup(): - temp_dir = tempfile.mkdtemp() - db_path = os.path.join(temp_dir, "test_persistence_access_control.db") - kvstore_config = SqliteKVStoreConfig(db_path=db_path) - kvstore = SqliteKVStoreImpl(kvstore_config) - await kvstore.initialize() - agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=kvstore) +async def test_setup(sqlite_kvstore): + agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore) yield agent_persistence - shutil.rmtree(temp_dir) @pytest.mark.asyncio diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 9896b3212..909581bb7 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -4,10 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os import pytest -import pytest_asyncio from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB @@ -20,28 +18,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -@pytest.fixture -def config(): - config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db") - if os.path.exists(config.db_path): - os.remove(config.db_path) - return config - - -@pytest_asyncio.fixture(scope="function") -async def registry(config): - registry = DiskDistributionRegistry(await kvstore_impl(config)) - await registry.initialize() - return registry - - -@pytest_asyncio.fixture(scope="function") -async def cached_registry(config): - registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) - await registry.initialize() - return registry - - @pytest.fixture def sample_vector_db(): return VectorDB( @@ -63,41 +39,42 @@ def sample_model(): @pytest.mark.asyncio -async def test_registry_initialization(registry): +async def test_registry_initialization(disk_dist_registry): # Test empty registry - result = await registry.get("nonexistent", "nonexistent") + result = await disk_dist_registry.get("nonexistent", "nonexistent") assert result is None @pytest.mark.asyncio -async def test_basic_registration(registry, sample_vector_db, sample_model): +async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model): print(f"Registering {sample_vector_db}") - await registry.register(sample_vector_db) + await disk_dist_registry.register(sample_vector_db) print(f"Registering {sample_model}") - await registry.register(sample_model) + await disk_dist_registry.register(sample_model) print("Getting vector_db") - result_vector_db = await registry.get("vector_db", "test_vector_db") + result_vector_db = await disk_dist_registry.get("vector_db", "test_vector_db") assert result_vector_db is not None assert result_vector_db.identifier == sample_vector_db.identifier assert result_vector_db.embedding_model == sample_vector_db.embedding_model assert result_vector_db.provider_id == sample_vector_db.provider_id - result_model = await registry.get("model", "test_model") + result_model = await disk_dist_registry.get("model", "test_model") assert result_model is not None assert result_model.identifier == sample_model.identifier assert result_model.provider_id == sample_model.provider_id @pytest.mark.asyncio -async def test_cached_registry_initialization(config, sample_vector_db, sample_model): +async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): # First populate the disk registry - disk_registry = DiskDistributionRegistry(await kvstore_impl(config)) + disk_registry = DiskDistributionRegistry(sqlite_kvstore) await disk_registry.initialize() await disk_registry.register(sample_vector_db) await disk_registry.register(sample_model) # Test cached version loads from disk - cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) + db_path = sqlite_kvstore.db_path + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await cached_registry.initialize() result_vector_db = await cached_registry.get("vector_db", "test_vector_db") @@ -109,10 +86,7 @@ async def test_cached_registry_initialization(config, sample_vector_db, sample_m @pytest.mark.asyncio -async def test_cached_registry_updates(config): - cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) - await cached_registry.initialize() - +async def test_cached_registry_updates(cached_disk_dist_registry): new_vector_db = VectorDB( identifier="test_vector_db_2", embedding_model="all-MiniLM-L6-v2", @@ -120,16 +94,17 @@ async def test_cached_registry_updates(config): provider_resource_id="test_vector_db_2", provider_id="baz", ) - await cached_registry.register(new_vector_db) + await cached_disk_dist_registry.register(new_vector_db) # Verify in cache - result_vector_db = await cached_registry.get("vector_db", "test_vector_db_2") + result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.provider_id == new_vector_db.provider_id # Verify persisted to disk - new_registry = DiskDistributionRegistry(await kvstore_impl(config)) + db_path = cached_disk_dist_registry.kvstore.db_path + new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await new_registry.initialize() result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None @@ -138,10 +113,7 @@ async def test_cached_registry_updates(config): @pytest.mark.asyncio -async def test_duplicate_provider_registration(config): - cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) - await cached_registry.initialize() - +async def test_duplicate_provider_registration(cached_disk_dist_registry): original_vector_db = VectorDB( identifier="test_vector_db_2", embedding_model="all-MiniLM-L6-v2", @@ -149,7 +121,7 @@ async def test_duplicate_provider_registration(config): provider_resource_id="test_vector_db_2", provider_id="baz", ) - await cached_registry.register(original_vector_db) + await cached_disk_dist_registry.register(original_vector_db) duplicate_vector_db = VectorDB( identifier="test_vector_db_2", @@ -158,18 +130,16 @@ async def test_duplicate_provider_registration(config): provider_resource_id="test_vector_db_2", provider_id="baz", # Same provider_id ) - await cached_registry.register(duplicate_vector_db) + await cached_disk_dist_registry.register(duplicate_vector_db) - result = await cached_registry.get("vector_db", "test_vector_db_2") + result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None assert result.embedding_model == original_vector_db.embedding_model # Original values preserved @pytest.mark.asyncio -async def test_get_all_objects(config): - cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) - await cached_registry.initialize() - +async def test_get_all_objects(cached_disk_dist_registry): + # Create multiple test banks # Create multiple test banks test_vector_dbs = [ VectorDB( @@ -184,10 +154,10 @@ async def test_get_all_objects(config): # Register all vector_dbs for vector_db in test_vector_dbs: - await cached_registry.register(vector_db) + await cached_disk_dist_registry.register(vector_db) # Test get_all retrieval - all_results = await cached_registry.get_all() + all_results = await cached_disk_dist_registry.get_all() assert len(all_results) == 3 # Verify each vector_db was stored correctly @@ -201,9 +171,7 @@ async def test_get_all_objects(config): @pytest.mark.asyncio -async def test_parse_registry_values_error_handling(config): - kvstore = await kvstore_impl(config) - +async def test_parse_registry_values_error_handling(sqlite_kvstore): valid_db = VectorDB( identifier="valid_vector_db", embedding_model="all-MiniLM-L6-v2", @@ -212,16 +180,18 @@ async def test_parse_registry_values_error_handling(config): provider_id="test-provider", ) - await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()) + await sqlite_kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json() + ) - await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") - await kvstore.set( + await sqlite_kvstore.set( KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), '{"type": "vector_db", "identifier": "missing_fields"}', ) - test_registry = DiskDistributionRegistry(kvstore) + test_registry = DiskDistributionRegistry(sqlite_kvstore) await test_registry.initialize() # Get all objects, which should only return the valid one @@ -240,9 +210,7 @@ async def test_parse_registry_values_error_handling(config): @pytest.mark.asyncio -async def test_cached_registry_error_handling(config): - kvstore = await kvstore_impl(config) - +async def test_cached_registry_error_handling(sqlite_kvstore): valid_db = VectorDB( identifier="valid_cached_db", embedding_model="all-MiniLM-L6-v2", @@ -251,14 +219,16 @@ async def test_cached_registry_error_handling(config): provider_id="test-provider", ) - await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()) + await sqlite_kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json() + ) - await kvstore.set( + await sqlite_kvstore.set( KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"), '{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string ) - cached_registry = CachedDiskDistributionRegistry(kvstore) + cached_registry = CachedDiskDistributionRegistry(sqlite_kvstore) await cached_registry.initialize() all_objects = await cached_registry.get_all() diff --git a/tests/unit/registry/test_registry_acl.py b/tests/unit/registry/test_registry_acl.py index 2a50b2840..25ea37bfa 100644 --- a/tests/unit/registry/test_registry_acl.py +++ b/tests/unit/registry/test_registry_acl.py @@ -4,9 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os -import shutil -import tempfile import pytest @@ -14,30 +11,10 @@ from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import ModelWithACL from llama_stack.distribution.server.auth_providers import AccessAttributes from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl - - -@pytest.fixture(scope="function") -async def kvstore(): - temp_dir = tempfile.mkdtemp() - db_path = os.path.join(temp_dir, "test_registry_acl.db") - kvstore_config = SqliteKVStoreConfig(db_path=db_path) - kvstore = SqliteKVStoreImpl(kvstore_config) - await kvstore.initialize() - yield kvstore - shutil.rmtree(temp_dir) - - -@pytest.fixture(scope="function") -async def registry(kvstore): - registry = CachedDiskDistributionRegistry(kvstore) - await registry.initialize() - return registry @pytest.mark.asyncio -async def test_registry_cache_with_acl(registry): +async def test_registry_cache_with_acl(cached_disk_dist_registry): model = ModelWithACL( identifier="model-acl", provider_id="test-provider", @@ -46,30 +23,30 @@ async def test_registry_cache_with_acl(registry): access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]), ) - success = await registry.register(model) + success = await cached_disk_dist_registry.register(model) assert success - cached_model = registry.get_cached("model", "model-acl") + cached_model = cached_disk_dist_registry.get_cached("model", "model-acl") assert cached_model is not None assert cached_model.identifier == "model-acl" assert cached_model.access_attributes.roles == ["admin"] assert cached_model.access_attributes.teams == ["ai-team"] - fetched_model = await registry.get("model", "model-acl") + fetched_model = await cached_disk_dist_registry.get("model", "model-acl") assert fetched_model is not None assert fetched_model.identifier == "model-acl" assert fetched_model.access_attributes.roles == ["admin"] model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"]) - await registry.update(model) + await cached_disk_dist_registry.update(model) - updated_cached = registry.get_cached("model", "model-acl") + updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl") assert updated_cached is not None assert updated_cached.access_attributes.roles == ["admin", "user"] assert updated_cached.access_attributes.projects == ["project-x"] assert updated_cached.access_attributes.teams is None - new_registry = CachedDiskDistributionRegistry(registry.kvstore) + new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore) await new_registry.initialize() new_model = await new_registry.get("model", "model-acl") @@ -81,7 +58,7 @@ async def test_registry_cache_with_acl(registry): @pytest.mark.asyncio -async def test_registry_empty_acl(registry): +async def test_registry_empty_acl(cached_disk_dist_registry): model = ModelWithACL( identifier="model-empty-acl", provider_id="test-provider", @@ -90,9 +67,9 @@ async def test_registry_empty_acl(registry): access_attributes=AccessAttributes(), ) - await registry.register(model) + await cached_disk_dist_registry.register(model) - cached_model = registry.get_cached("model", "model-empty-acl") + cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl") assert cached_model is not None assert cached_model.access_attributes is not None assert cached_model.access_attributes.roles is None @@ -100,7 +77,7 @@ async def test_registry_empty_acl(registry): assert cached_model.access_attributes.projects is None assert cached_model.access_attributes.namespaces is None - all_models = await registry.get_all() + all_models = await cached_disk_dist_registry.get_all() assert len(all_models) == 1 model = ModelWithACL( @@ -110,18 +87,18 @@ async def test_registry_empty_acl(registry): model_type=ModelType.llm, ) - await registry.register(model) + await cached_disk_dist_registry.register(model) - cached_model = registry.get_cached("model", "model-no-acl") + cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl") assert cached_model is not None assert cached_model.access_attributes is None - all_models = await registry.get_all() + all_models = await cached_disk_dist_registry.get_all() assert len(all_models) == 2 @pytest.mark.asyncio -async def test_registry_serialization(registry): +async def test_registry_serialization(cached_disk_dist_registry): attributes = AccessAttributes( roles=["admin", "researcher"], teams=["ai-team", "ml-team"], @@ -137,9 +114,9 @@ async def test_registry_serialization(registry): access_attributes=attributes, ) - await registry.register(model) + await cached_disk_dist_registry.register(model) - new_registry = CachedDiskDistributionRegistry(registry.kvstore) + new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore) await new_registry.initialize() loaded_model = await new_registry.get("model", "model-serialize") diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index 7d92a5cf5..b5e9c2698 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -4,9 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os -import shutil -import tempfile from unittest.mock import MagicMock, Mock, patch import pytest @@ -15,9 +12,6 @@ from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable -from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl class AsyncMock(MagicMock): @@ -30,25 +24,16 @@ def _return_model(model): @pytest.fixture -async def test_setup(): - temp_dir = tempfile.mkdtemp() - db_path = os.path.join(temp_dir, "test_access_control.db") - kvstore_config = SqliteKVStoreConfig(db_path=db_path) - kvstore = SqliteKVStoreImpl(kvstore_config) - await kvstore.initialize() - registry = CachedDiskDistributionRegistry(kvstore) - await registry.initialize() - +async def test_setup(cached_disk_dist_registry): mock_inference = Mock() mock_inference.__provider_spec__ = MagicMock() mock_inference.__provider_spec__.api = Api.inference mock_inference.register_model = AsyncMock(side_effect=_return_model) routing_table = ModelsRoutingTable( impls_by_provider_id={"test_provider": mock_inference}, - dist_registry=registry, + dist_registry=cached_disk_dist_registry, ) - yield registry, routing_table - shutil.rmtree(temp_dir) + yield cached_disk_dist_registry, routing_table @pytest.mark.asyncio