diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index 54bc04f9c..9c5b72f93 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -15,7 +15,8 @@ from llama_stack.distribution.store.registry import ( CachedDiskDistributionRegistry, DiskDistributionRegistry, ) -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @pytest.fixture @@ -26,14 +27,14 @@ def config(): return config -@pytest_asyncio.fixture +@pytest_asyncio.fixture(scope="function") async def registry(config): registry = DiskDistributionRegistry(await kvstore_impl(config)) await registry.initialize() return registry -@pytest_asyncio.fixture +@pytest_asyncio.fixture(scope="function") async def cached_registry(config): registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) await registry.initialize() @@ -64,8 +65,8 @@ def sample_model(): @pytest.mark.asyncio async def test_registry_initialization(registry): # Test empty registry - results = await registry.get("nonexistent", "nonexistent") - assert len(results) == 0 + result = await registry.get("nonexistent", "nonexistent") + assert result is None @pytest.mark.asyncio @@ -75,18 +76,16 @@ async def test_basic_registration(registry, sample_bank, sample_model): print(f"Registering {sample_model}") await registry.register(sample_model) print("Getting bank") - results = await registry.get("memory_bank", "test_bank") - assert len(results) == 1 - result_bank = results[0] + result_bank = await registry.get("memory_bank", "test_bank") + assert result_bank is not None assert result_bank.identifier == sample_bank.identifier assert result_bank.embedding_model == sample_bank.embedding_model assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens assert result_bank.provider_id == sample_bank.provider_id - results = await registry.get("model", "test_model") - assert len(results) == 1 - result_model = results[0] + result_model = await registry.get("model", "test_model") + assert result_model is not None assert result_model.identifier == sample_model.identifier assert result_model.provider_id == sample_model.provider_id @@ -103,9 +102,8 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model) cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) await cached_registry.initialize() - results = await cached_registry.get("memory_bank", "test_bank") - assert len(results) == 1 - result_bank = results[0] + result_bank = await cached_registry.get("memory_bank", "test_bank") + assert result_bank is not None assert result_bank.identifier == sample_bank.identifier assert result_bank.embedding_model == sample_bank.embedding_model assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens @@ -129,18 +127,16 @@ async def test_cached_registry_updates(config): await cached_registry.register(new_bank) # Verify in cache - results = await cached_registry.get("memory_bank", "test_bank_2") - assert len(results) == 1 - result_bank = results[0] + result_bank = await cached_registry.get("memory_bank", "test_bank_2") + assert result_bank is not None assert result_bank.identifier == new_bank.identifier assert result_bank.provider_id == new_bank.provider_id # Verify persisted to disk new_registry = DiskDistributionRegistry(await kvstore_impl(config)) await new_registry.initialize() - results = await new_registry.get("memory_bank", "test_bank_2") - assert len(results) == 1 - result_bank = results[0] + result_bank = await new_registry.get("memory_bank", "test_bank_2") + assert result_bank is not None assert result_bank.identifier == new_bank.identifier assert result_bank.provider_id == new_bank.provider_id @@ -170,10 +166,10 @@ async def test_duplicate_provider_registration(config): ) await cached_registry.register(duplicate_bank) - results = await cached_registry.get("memory_bank", "test_bank_2") - assert len(results) == 1 # Still only one result + result = await cached_registry.get("memory_bank", "test_bank_2") + assert result is not None assert ( - results[0].embedding_model == original_bank.embedding_model + result.embedding_model == original_bank.embedding_model ) # Original values preserved