diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index a5e3fafa1..9873bec5b 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -10,9 +10,9 @@ import pytest from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.core.store.registry import ( + KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, - KEY_FORMAT, ) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -62,9 +62,7 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m assert result_model.provider_id == sample_model.provider_id -async def test_cached_registry_initialization( - sqlite_kvstore, sample_vector_db, sample_model -): +async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): # First populate the disk registry disk_registry = DiskDistributionRegistry(sqlite_kvstore) await disk_registry.initialize() @@ -73,9 +71,7 @@ async def test_cached_registry_initialization( # Test cached version loads from disk db_path = sqlite_kvstore.db_path - cached_registry = CachedDiskDistributionRegistry( - await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) - ) + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await cached_registry.initialize() result_vector_db = await cached_registry.get("vector_db", "test_vector_db") @@ -97,18 +93,14 @@ async def test_cached_registry_updates(cached_disk_dist_registry): await cached_disk_dist_registry.register(new_vector_db) # Verify in cache - result_vector_db = await cached_disk_dist_registry.get( - "vector_db", "test_vector_db_2" - ) + result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.provider_id == new_vector_db.provider_id # Verify persisted to disk db_path = cached_disk_dist_registry.kvstore.db_path - new_registry = DiskDistributionRegistry( - await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) - ) + new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await new_registry.initialize() result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None @@ -137,9 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None - assert ( - result.embedding_model == duplicate_vector_db.embedding_model - ) # Original values preserved + assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved async def test_get_all_objects(cached_disk_dist_registry): @@ -166,17 +156,12 @@ async def test_get_all_objects(cached_disk_dist_registry): # Verify each vector_db was stored correctly for original_vector_db in test_vector_dbs: - matching_vector_dbs = [ - v for v in all_results if v.identifier == original_vector_db.identifier - ] + matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier] assert len(matching_vector_dbs) == 1 stored_vector_db = matching_vector_dbs[0] assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id - assert ( - stored_vector_db.embedding_dimension - == original_vector_db.embedding_dimension - ) + assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension async def test_parse_registry_values_error_handling(sqlite_kvstore):