From f0545a3fec94758f0e2bd1e07f48cd1605c497ae Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Tue, 30 Sep 2025 14:59:32 -0700 Subject: [PATCH 1/4] Added an update to the registration function to reject a change for the same provider with a message asking the caller to unregister the model first --- llama_stack/core/store/registry.py | 5 ++++- tests/unit/registry/test_registry.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 5f4abe9aa..5491ab0f1 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -98,7 +98,10 @@ class DiskDistributionRegistry(DistributionRegistry): existing_obj = await self.get(obj.type, obj.identifier) # dont register if the object's providerid already exists if existing_obj and existing_obj.provider_id == obj.provider_id: - return False + raise ValueError( + f"Model '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " + f"Please unregister the existing model first before registering a new one." + ) await self.kvstore.set( KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 4ea4a20b9..5d8add4bf 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -125,8 +125,12 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): provider_resource_id="test_vector_db_2", provider_id="baz", # Same provider_id ) - await cached_disk_dist_registry.register(duplicate_vector_db) + # Now we expect a ValueError to be raised for duplicate registration + with pytest.raises(ValueError, match=r".*already registered.*"): + await cached_disk_dist_registry.register(duplicate_vector_db) + + # Verify the original registration is still intact result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None assert result.embedding_model == original_vector_db.embedding_model # Original values preserved From 5943ce404be12bc7448ca148b3ce2cb92344fcd2 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Tue, 30 Sep 2025 17:03:46 -0700 Subject: [PATCH 2/4] updated the error message --- llama_stack/core/store/registry.py | 49 ++++++++++++++++++++-------- tests/unit/registry/test_registry.py | 46 +++++++++++++++++++------- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 5491ab0f1..9f4d774e8 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -14,7 +14,10 @@ from llama_stack.core.datatypes import RoutableObjectWithProvider from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) logger = get_logger(__name__, category="core::registry") @@ -28,7 +31,9 @@ class DistributionRegistry(Protocol): def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ... - async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ... + async def update( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: ... async def register(self, obj: RoutableObjectWithProvider) -> bool: ... @@ -54,7 +59,9 @@ def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) all_objects.append(obj) except pydantic.ValidationError as e: - logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}") + logger.error( + f"Error parsing registry value, raw value: {value}. Error: {e}" + ) continue return all_objects @@ -67,7 +74,9 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + def get_cached( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: # Disk registry does not have a cache raise NotImplementedError("Disk registry does not have a cache") @@ -76,15 +85,23 @@ class DiskDistributionRegistry(DistributionRegistry): values = await self.kvstore.values_in_range(start_key, end_key) return _parse_registry_values(values) - async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: - json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier)) + async def get( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: + json_str = await self.kvstore.get( + KEY_FORMAT.format(type=type, identifier=identifier) + ) if not json_str: return None try: - return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) + return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json( + json_str + ) except pydantic.ValidationError as e: - logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}") + logger.error( + f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}" + ) return None async def update(self, obj: RoutableObjectWithProvider) -> None: @@ -99,8 +116,8 @@ class DiskDistributionRegistry(DistributionRegistry): # dont register if the object's providerid already exists if existing_obj and existing_obj.provider_id == obj.provider_id: raise ValueError( - f"Model '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " - f"Please unregister the existing model first before registering a new one." + f"{obj.type.title()} '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " + f"Please unregister the existing {obj.type} first before registering a new one." ) await self.kvstore.set( @@ -150,7 +167,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def initialize(self) -> None: await self._ensure_initialized() - def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + def get_cached( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: return self.cache.get((type, identifier), None) async def get_all(self) -> list[RoutableObjectWithProvider]: @@ -158,7 +177,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async with self._locked_cache() as cache: return list(cache.values()) - async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + async def get( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: await self._ensure_initialized() cache_key = (type, identifier) @@ -200,7 +221,9 @@ async def create_dist_registry( dist_kvstore = await kvstore_impl(metadata_store) else: dist_kvstore = await kvstore_impl( - SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()) + SqliteKVStoreConfig( + db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix() + ) ) dist_registry = CachedDiskDistributionRegistry(dist_kvstore) await dist_registry.initialize() diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 5d8add4bf..83dcf0968 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -10,9 +10,9 @@ import pytest from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.core.store.registry import ( - KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, + KEY_FORMAT, ) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -62,7 +62,9 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m assert result_model.provider_id == sample_model.provider_id -async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): +async def test_cached_registry_initialization( + sqlite_kvstore, sample_vector_db, sample_model +): # First populate the disk registry disk_registry = DiskDistributionRegistry(sqlite_kvstore) await disk_registry.initialize() @@ -71,7 +73,9 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, # Test cached version loads from disk db_path = sqlite_kvstore.db_path - cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) + cached_registry = CachedDiskDistributionRegistry( + await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + ) await cached_registry.initialize() result_vector_db = await cached_registry.get("vector_db", "test_vector_db") @@ -93,14 +97,18 @@ async def test_cached_registry_updates(cached_disk_dist_registry): await cached_disk_dist_registry.register(new_vector_db) # Verify in cache - result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") + result_vector_db = await cached_disk_dist_registry.get( + "vector_db", "test_vector_db_2" + ) assert result_vector_db is not None assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.provider_id == new_vector_db.provider_id # Verify persisted to disk db_path = cached_disk_dist_registry.kvstore.db_path - new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) + new_registry = DiskDistributionRegistry( + await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + ) await new_registry.initialize() result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None @@ -127,13 +135,17 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): ) # Now we expect a ValueError to be raised for duplicate registration - with pytest.raises(ValueError, match=r".*already registered.*"): + with pytest.raises( + ValueError, match=r"Vector_db.*already registered.*provider.*baz.*" + ): await cached_disk_dist_registry.register(duplicate_vector_db) # Verify the original registration is still intact result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None - assert result.embedding_model == original_vector_db.embedding_model # Original values preserved + assert ( + result.embedding_model == original_vector_db.embedding_model + ) # Original values preserved async def test_get_all_objects(cached_disk_dist_registry): @@ -160,12 +172,17 @@ async def test_get_all_objects(cached_disk_dist_registry): # Verify each vector_db was stored correctly for original_vector_db in test_vector_dbs: - matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier] + matching_vector_dbs = [ + v for v in all_results if v.identifier == original_vector_db.identifier + ] assert len(matching_vector_dbs) == 1 stored_vector_db = matching_vector_dbs[0] assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id - assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension + assert ( + stored_vector_db.embedding_dimension + == original_vector_db.embedding_dimension + ) async def test_parse_registry_values_error_handling(sqlite_kvstore): @@ -178,10 +195,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), + valid_db.model_dump_json(), ) - await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + await sqlite_kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), + "{not valid json", + ) await sqlite_kvstore.set( KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), @@ -216,7 +237,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), + valid_db.model_dump_json(), ) await sqlite_kvstore.set( From 2d4775c67a0bd5ae2db6b20431970f00ba8c3dfc Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Tue, 30 Sep 2025 17:12:43 -0700 Subject: [PATCH 3/4] updated the error message and test_registry --- llama_stack/core/store/registry.py | 2 +- tests/unit/registry/test_registry.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 9f4d774e8..4c34f52e1 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -117,7 +117,7 @@ class DiskDistributionRegistry(DistributionRegistry): if existing_obj and existing_obj.provider_id == obj.provider_id: raise ValueError( f"{obj.type.title()} '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " - f"Please unregister the existing {obj.type} first before registering a new one." + f"Unregister the existing object first before registering a new one." ) await self.kvstore.set( diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 83dcf0968..e39a421d5 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -136,7 +136,8 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): # Now we expect a ValueError to be raised for duplicate registration with pytest.raises( - ValueError, match=r"Vector_db.*already registered.*provider.*baz.*" + ValueError, + match=r"Vector_Db.*already registered.*provider.*baz.*Unregister the existing", ): await cached_disk_dist_registry.register(duplicate_vector_db) From d5d2061c8c0daf86b52850f726ccf541ca3b333c Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Tue, 30 Sep 2025 17:25:13 -0700 Subject: [PATCH 4/4] updated the error message and ran pre-commit --- llama_stack/core/store/registry.py | 44 ++++++++-------------------- tests/unit/registry/test_registry.py | 33 ++++++--------------- 2 files changed, 21 insertions(+), 56 deletions(-) diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 4c34f52e1..6f6e9dde2 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -31,9 +31,7 @@ class DistributionRegistry(Protocol): def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ... - async def update( - self, obj: RoutableObjectWithProvider - ) -> RoutableObjectWithProvider: ... + async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ... async def register(self, obj: RoutableObjectWithProvider) -> bool: ... @@ -59,9 +57,7 @@ def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) all_objects.append(obj) except pydantic.ValidationError as e: - logger.error( - f"Error parsing registry value, raw value: {value}. Error: {e}" - ) + logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}") continue return all_objects @@ -74,9 +70,7 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: + def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: # Disk registry does not have a cache raise NotImplementedError("Disk registry does not have a cache") @@ -85,23 +79,15 @@ class DiskDistributionRegistry(DistributionRegistry): values = await self.kvstore.values_in_range(start_key, end_key) return _parse_registry_values(values) - async def get( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: - json_str = await self.kvstore.get( - KEY_FORMAT.format(type=type, identifier=identifier) - ) + async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier)) if not json_str: return None try: - return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json( - json_str - ) + return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) except pydantic.ValidationError as e: - logger.error( - f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}" - ) + logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}") return None async def update(self, obj: RoutableObjectWithProvider) -> None: @@ -116,8 +102,8 @@ class DiskDistributionRegistry(DistributionRegistry): # dont register if the object's providerid already exists if existing_obj and existing_obj.provider_id == obj.provider_id: raise ValueError( - f"{obj.type.title()} '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " - f"Unregister the existing object first before registering a new one." + f"Provider '{obj.provider_id}' is already registered." + f"Unregister the existing provider first before registering it again." ) await self.kvstore.set( @@ -167,9 +153,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def initialize(self) -> None: await self._ensure_initialized() - def get_cached( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: + def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: return self.cache.get((type, identifier), None) async def get_all(self) -> list[RoutableObjectWithProvider]: @@ -177,9 +161,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async with self._locked_cache() as cache: return list(cache.values()) - async def get( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: + async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: await self._ensure_initialized() cache_key = (type, identifier) @@ -221,9 +203,7 @@ async def create_dist_registry( dist_kvstore = await kvstore_impl(metadata_store) else: dist_kvstore = await kvstore_impl( - SqliteKVStoreConfig( - db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix() - ) + SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()) ) dist_registry = CachedDiskDistributionRegistry(dist_kvstore) await dist_registry.initialize() diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index e39a421d5..e17855b40 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -10,9 +10,9 @@ import pytest from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.core.store.registry import ( + KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, - KEY_FORMAT, ) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -62,9 +62,7 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m assert result_model.provider_id == sample_model.provider_id -async def test_cached_registry_initialization( - sqlite_kvstore, sample_vector_db, sample_model -): +async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): # First populate the disk registry disk_registry = DiskDistributionRegistry(sqlite_kvstore) await disk_registry.initialize() @@ -73,9 +71,7 @@ async def test_cached_registry_initialization( # Test cached version loads from disk db_path = sqlite_kvstore.db_path - cached_registry = CachedDiskDistributionRegistry( - await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) - ) + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await cached_registry.initialize() result_vector_db = await cached_registry.get("vector_db", "test_vector_db") @@ -97,18 +93,14 @@ async def test_cached_registry_updates(cached_disk_dist_registry): await cached_disk_dist_registry.register(new_vector_db) # Verify in cache - result_vector_db = await cached_disk_dist_registry.get( - "vector_db", "test_vector_db_2" - ) + result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.provider_id == new_vector_db.provider_id # Verify persisted to disk db_path = cached_disk_dist_registry.kvstore.db_path - new_registry = DiskDistributionRegistry( - await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) - ) + new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await new_registry.initialize() result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None @@ -137,16 +129,14 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): # Now we expect a ValueError to be raised for duplicate registration with pytest.raises( ValueError, - match=r"Vector_Db.*already registered.*provider.*baz.*Unregister the existing", + match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.", ): await cached_disk_dist_registry.register(duplicate_vector_db) # Verify the original registration is still intact result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None - assert ( - result.embedding_model == original_vector_db.embedding_model - ) # Original values preserved + assert result.embedding_model == original_vector_db.embedding_model # Original values preserved async def test_get_all_objects(cached_disk_dist_registry): @@ -173,17 +163,12 @@ async def test_get_all_objects(cached_disk_dist_registry): # Verify each vector_db was stored correctly for original_vector_db in test_vector_dbs: - matching_vector_dbs = [ - v for v in all_results if v.identifier == original_vector_db.identifier - ] + matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier] assert len(matching_vector_dbs) == 1 stored_vector_db = matching_vector_dbs[0] assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id - assert ( - stored_vector_db.embedding_dimension - == original_vector_db.embedding_dimension - ) + assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension async def test_parse_registry_values_error_handling(sqlite_kvstore):