Revert "fix: Added a bug fix when registering new models (#3453)"

This reverts commit e0e2b1bd0e.
This commit is contained in:
Matthew Farrellee 2025-09-17 10:36:29 -04:00 committed by GitHub
parent 9acf49753e
commit 506b8ed744
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 14 deletions

View file

@ -96,11 +96,9 @@ class DiskDistributionRegistry(DistributionRegistry):
async def register(self, obj: RoutableObjectWithProvider) -> bool: async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier) existing_obj = await self.get(obj.type, obj.identifier)
# warn if the object's providerid is different but proceed with registration # dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id != obj.provider_id: if existing_obj and existing_obj.provider_id == obj.provider_id:
logger.warning( return False
f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}"
)
await self.kvstore.set( await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),

View file

@ -129,7 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None assert result is not None
assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
async def test_get_all_objects(cached_disk_dist_registry): async def test_get_all_objects(cached_disk_dist_registry):
@ -174,14 +174,10 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
valid_db.model_dump_json(),
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
"{not valid json",
)
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
@ -216,8 +212,7 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
valid_db.model_dump_json(),
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(