From 87f1b6994c674ae0efcb460fe98b55befd78f773 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Mon, 15 Sep 2025 18:21:53 -0700 Subject: [PATCH] Fix duplicate comment in test_registry.py --- tests/unit/registry/test_registry.py | 42 ++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 4ea4a20b9..71ecf859f 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -10,9 +10,9 @@ import pytest from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.core.store.registry import ( - KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, + KEY_FORMAT, ) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -62,7 +62,9 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m assert result_model.provider_id == sample_model.provider_id -async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): +async def test_cached_registry_initialization( + sqlite_kvstore, sample_vector_db, sample_model +): # First populate the disk registry disk_registry = DiskDistributionRegistry(sqlite_kvstore) await disk_registry.initialize() @@ -71,7 +73,9 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, # Test cached version loads from disk db_path = sqlite_kvstore.db_path - cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) + cached_registry = CachedDiskDistributionRegistry( + await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + ) await cached_registry.initialize() result_vector_db = await cached_registry.get("vector_db", "test_vector_db") @@ -93,14 +97,18 @@ async def test_cached_registry_updates(cached_disk_dist_registry): await cached_disk_dist_registry.register(new_vector_db) # Verify in cache - result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") + result_vector_db = await cached_disk_dist_registry.get( + "vector_db", "test_vector_db_2" + ) assert result_vector_db is not None assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.provider_id == new_vector_db.provider_id # Verify persisted to disk db_path = cached_disk_dist_registry.kvstore.db_path - new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) + new_registry = DiskDistributionRegistry( + await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + ) await new_registry.initialize() result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None @@ -129,7 +137,9 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None - assert result.embedding_model == original_vector_db.embedding_model # Original values preserved + assert ( + result.embedding_model == duplicate_vector_db.embedding_model + ) # Warning is logged, but we still return the latest async def test_get_all_objects(cached_disk_dist_registry): @@ -156,12 +166,17 @@ async def test_get_all_objects(cached_disk_dist_registry): # Verify each vector_db was stored correctly for original_vector_db in test_vector_dbs: - matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier] + matching_vector_dbs = [ + v for v in all_results if v.identifier == original_vector_db.identifier + ] assert len(matching_vector_dbs) == 1 stored_vector_db = matching_vector_dbs[0] assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id - assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension + assert ( + stored_vector_db.embedding_dimension + == original_vector_db.embedding_dimension + ) async def test_parse_registry_values_error_handling(sqlite_kvstore): @@ -174,10 +189,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), + valid_db.model_dump_json(), ) - await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + await sqlite_kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), + "{not valid json", + ) await sqlite_kvstore.set( KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), @@ -212,7 +231,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), + valid_db.model_dump_json(), ) await sqlite_kvstore.set(