diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 5f4abe9aa..a764d692a 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -96,9 +96,11 @@ class DiskDistributionRegistry(DistributionRegistry): async def register(self, obj: RoutableObjectWithProvider) -> bool: existing_obj = await self.get(obj.type, obj.identifier) - # dont register if the object's providerid already exists - if existing_obj and existing_obj.provider_id == obj.provider_id: - return False + # warn if the object's providerid is different but proceed with registration + if existing_obj and existing_obj.provider_id != obj.provider_id: + logger.warning( + f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}" + ) await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 4ea4a20b9..9873bec5b 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -129,7 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None - assert result.embedding_model == original_vector_db.embedding_model # Original values preserved + assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved async def test_get_all_objects(cached_disk_dist_registry): @@ -174,10 +174,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), + valid_db.model_dump_json(), ) - await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + await sqlite_kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), + "{not valid json", + ) await sqlite_kvstore.set( KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), @@ -212,7 +216,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), + valid_db.model_dump_json(), ) await sqlite_kvstore.set(