Fix broken tests in test_registry (#707)

Summary:
Tests were broken after registry.get return type was changed from
`List[RoutableObjectWithProvider]` to
`Optional[RoutableObjectWithProvider]` in
efe791bab7 (diff-5de152bae521b7baef01048a4c0142484f8f1c978a04f3b55f4e4dabc52835beL29)

Test Plan:
Run tests
```
pytest llama_stack/distribution/store/tests/test_registry.py -v

collected 6 items

llama_stack/distribution/store/tests/test_registry.py::test_registry_initialization PASSED                                                                  [ 16%]
llama_stack/distribution/store/tests/test_registry.py::test_basic_registration PASSED                                                                       [ 33%]
llama_stack/distribution/store/tests/test_registry.py::test_cached_registry_initialization PASSED                                                           [ 50%]
llama_stack/distribution/store/tests/test_registry.py::test_cached_registry_updates PASSED                                                                  [ 66%]
llama_stack/distribution/store/tests/test_registry.py::test_duplicate_provider_registration PASSED                                                          [ 83%]
llama_stack/distribution/store/tests/test_registry.py::test_get_all_objects PASSED                                                                          [100%]
```
This commit is contained in:
Vladimir Ivić 2025-01-14 14:33:15 -08:00 committed by GitHub
parent 91907b714e
commit 472feea8d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -15,7 +15,8 @@ from llama_stack.distribution.store.registry import (
CachedDiskDistributionRegistry, CachedDiskDistributionRegistry,
DiskDistributionRegistry, DiskDistributionRegistry,
) )
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture @pytest.fixture
@ -26,14 +27,14 @@ def config():
return config return config
@pytest_asyncio.fixture @pytest_asyncio.fixture(scope="function")
async def registry(config): async def registry(config):
registry = DiskDistributionRegistry(await kvstore_impl(config)) registry = DiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize() await registry.initialize()
return registry return registry
@pytest_asyncio.fixture @pytest_asyncio.fixture(scope="function")
async def cached_registry(config): async def cached_registry(config):
registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await registry.initialize() await registry.initialize()
@ -64,8 +65,8 @@ def sample_model():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry_initialization(registry): async def test_registry_initialization(registry):
# Test empty registry # Test empty registry
results = await registry.get("nonexistent", "nonexistent") result = await registry.get("nonexistent", "nonexistent")
assert len(results) == 0 assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@ -75,18 +76,16 @@ async def test_basic_registration(registry, sample_bank, sample_model):
print(f"Registering {sample_model}") print(f"Registering {sample_model}")
await registry.register(sample_model) await registry.register(sample_model)
print("Getting bank") print("Getting bank")
results = await registry.get("memory_bank", "test_bank") result_bank = await registry.get("memory_bank", "test_bank")
assert len(results) == 1 assert result_bank is not None
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier assert result_bank.identifier == sample_bank.identifier
assert result_bank.embedding_model == sample_bank.embedding_model assert result_bank.embedding_model == sample_bank.embedding_model
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
assert result_bank.provider_id == sample_bank.provider_id assert result_bank.provider_id == sample_bank.provider_id
results = await registry.get("model", "test_model") result_model = await registry.get("model", "test_model")
assert len(results) == 1 assert result_model is not None
result_model = results[0]
assert result_model.identifier == sample_model.identifier assert result_model.identifier == sample_model.identifier
assert result_model.provider_id == sample_model.provider_id assert result_model.provider_id == sample_model.provider_id
@ -103,9 +102,8 @@ async def test_cached_registry_initialization(config, sample_bank, sample_model)
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
await cached_registry.initialize() await cached_registry.initialize()
results = await cached_registry.get("memory_bank", "test_bank") result_bank = await cached_registry.get("memory_bank", "test_bank")
assert len(results) == 1 assert result_bank is not None
result_bank = results[0]
assert result_bank.identifier == sample_bank.identifier assert result_bank.identifier == sample_bank.identifier
assert result_bank.embedding_model == sample_bank.embedding_model assert result_bank.embedding_model == sample_bank.embedding_model
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
@ -129,18 +127,16 @@ async def test_cached_registry_updates(config):
await cached_registry.register(new_bank) await cached_registry.register(new_bank)
# Verify in cache # Verify in cache
results = await cached_registry.get("memory_bank", "test_bank_2") result_bank = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1 assert result_bank is not None
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier assert result_bank.identifier == new_bank.identifier
assert result_bank.provider_id == new_bank.provider_id assert result_bank.provider_id == new_bank.provider_id
# Verify persisted to disk # Verify persisted to disk
new_registry = DiskDistributionRegistry(await kvstore_impl(config)) new_registry = DiskDistributionRegistry(await kvstore_impl(config))
await new_registry.initialize() await new_registry.initialize()
results = await new_registry.get("memory_bank", "test_bank_2") result_bank = await new_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1 assert result_bank is not None
result_bank = results[0]
assert result_bank.identifier == new_bank.identifier assert result_bank.identifier == new_bank.identifier
assert result_bank.provider_id == new_bank.provider_id assert result_bank.provider_id == new_bank.provider_id
@ -170,10 +166,10 @@ async def test_duplicate_provider_registration(config):
) )
await cached_registry.register(duplicate_bank) await cached_registry.register(duplicate_bank)
results = await cached_registry.get("memory_bank", "test_bank_2") result = await cached_registry.get("memory_bank", "test_bank_2")
assert len(results) == 1 # Still only one result assert result is not None
assert ( assert (
results[0].embedding_model == original_bank.embedding_model result.embedding_model == original_bank.embedding_model
) # Original values preserved ) # Original values preserved