fix: Added a bug fix when registering new models (#3453)

# What does this PR do?

Modified the code in registry.py.

The key changes are:

1.  Removed the `return False` statement
2. Added a warning log message that includes the object type,
identifier, and provider_id for better debugging.
3. The method now continues with the registration process instead of
early returning.

---------

Co-authored-by: Omar Abdelwahab <omara@fb.com>
This commit is contained in:
Omar Abdelwahab 2025-09-16 19:09:06 -07:00 committed by GitHub
parent ececc323d3
commit e0e2b1bd0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 7 deletions

View file

@ -96,9 +96,11 @@ class DiskDistributionRegistry(DistributionRegistry):
async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id == obj.provider_id:
return False
# warn if the object's providerid is different but proceed with registration
if existing_obj and existing_obj.provider_id != obj.provider_id:
logger.warning(
f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}"
)
await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),

View file

@ -129,7 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved
async def test_get_all_objects(cached_disk_dist_registry):
@ -174,10 +174,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"),
valid_db.model_dump_json(),
)
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
"{not valid json",
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
@ -212,7 +216,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"),
valid_db.model_dump_json(),
)
await sqlite_kvstore.set(