diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 624dbd176..234a994f9 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -14,7 +14,10 @@ from llama_stack.core.datatypes import RoutableObjectWithProvider from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) logger = get_logger(__name__, category="core::registry") @@ -98,7 +101,10 @@ class DiskDistributionRegistry(DistributionRegistry): existing_obj = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists if existing_obj and existing_obj.provider_id == obj.provider_id: - return False + raise ValueError( + f"Provider '{obj.provider_id}' is already registered." + f"Unregister the existing provider first before registering it again." + ) await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 4ea4a20b9..e17855b40 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): provider_resource_id="test_vector_db_2", provider_id="baz", # Same provider_id ) - await cached_disk_dist_registry.register(duplicate_vector_db) + # Now we expect a ValueError to be raised for duplicate registration + with pytest.raises( + ValueError, + match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.", + ): + await cached_disk_dist_registry.register(duplicate_vector_db) + + # Verify the original registration is still intact result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None assert result.embedding_model == original_vector_db.embedding_model # Original values preserved @@ -174,10 +181,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), + valid_db.model_dump_json(), ) - await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + await sqlite_kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), + "{not valid json", + ) await sqlite_kvstore.set( KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), @@ -212,7 +223,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), + valid_db.model_dump_json(), ) await sqlite_kvstore.set(