diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 4c34f52e1..6f6e9dde2 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -31,9 +31,7 @@ class DistributionRegistry(Protocol): def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ... - async def update( - self, obj: RoutableObjectWithProvider - ) -> RoutableObjectWithProvider: ... + async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ... async def register(self, obj: RoutableObjectWithProvider) -> bool: ... @@ -59,9 +57,7 @@ def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) all_objects.append(obj) except pydantic.ValidationError as e: - logger.error( - f"Error parsing registry value, raw value: {value}. Error: {e}" - ) + logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}") continue return all_objects @@ -74,9 +70,7 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: + def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: # Disk registry does not have a cache raise NotImplementedError("Disk registry does not have a cache") @@ -85,23 +79,15 @@ class DiskDistributionRegistry(DistributionRegistry): values = await self.kvstore.values_in_range(start_key, end_key) return _parse_registry_values(values) - async def get( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: - json_str = await self.kvstore.get( - KEY_FORMAT.format(type=type, identifier=identifier) - ) + async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier)) if not json_str: return None try: - return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json( - json_str - ) + return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) except pydantic.ValidationError as e: - logger.error( - f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}" - ) + logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}") return None async def update(self, obj: RoutableObjectWithProvider) -> None: @@ -116,8 +102,8 @@ class DiskDistributionRegistry(DistributionRegistry): # dont register if the object's providerid already exists if existing_obj and existing_obj.provider_id == obj.provider_id: raise ValueError( - f"{obj.type.title()} '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " - f"Unregister the existing object first before registering a new one." + f"Provider '{obj.provider_id}' is already registered." + f"Unregister the existing provider first before registering it again." ) await self.kvstore.set( @@ -167,9 +153,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def initialize(self) -> None: await self._ensure_initialized() - def get_cached( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: + def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: return self.cache.get((type, identifier), None) async def get_all(self) -> list[RoutableObjectWithProvider]: @@ -177,9 +161,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async with self._locked_cache() as cache: return list(cache.values()) - async def get( - self, type: str, identifier: str - ) -> RoutableObjectWithProvider | None: + async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: await self._ensure_initialized() cache_key = (type, identifier) @@ -221,9 +203,7 @@ async def create_dist_registry( dist_kvstore = await kvstore_impl(metadata_store) else: dist_kvstore = await kvstore_impl( - SqliteKVStoreConfig( - db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix() - ) + SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()) ) dist_registry = CachedDiskDistributionRegistry(dist_kvstore) await dist_registry.initialize() diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index e39a421d5..e17855b40 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -10,9 +10,9 @@ import pytest from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.core.store.registry import ( + KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, - KEY_FORMAT, ) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -62,9 +62,7 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m assert result_model.provider_id == sample_model.provider_id -async def test_cached_registry_initialization( - sqlite_kvstore, sample_vector_db, sample_model -): +async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): # First populate the disk registry disk_registry = DiskDistributionRegistry(sqlite_kvstore) await disk_registry.initialize() @@ -73,9 +71,7 @@ async def test_cached_registry_initialization( # Test cached version loads from disk db_path = sqlite_kvstore.db_path - cached_registry = CachedDiskDistributionRegistry( - await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) - ) + cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await cached_registry.initialize() result_vector_db = await cached_registry.get("vector_db", "test_vector_db") @@ -97,18 +93,14 @@ async def test_cached_registry_updates(cached_disk_dist_registry): await cached_disk_dist_registry.register(new_vector_db) # Verify in cache - result_vector_db = await cached_disk_dist_registry.get( - "vector_db", "test_vector_db_2" - ) + result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.provider_id == new_vector_db.provider_id # Verify persisted to disk db_path = cached_disk_dist_registry.kvstore.db_path - new_registry = DiskDistributionRegistry( - await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) - ) + new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) await new_registry.initialize() result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None @@ -137,16 +129,14 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): # Now we expect a ValueError to be raised for duplicate registration with pytest.raises( ValueError, - match=r"Vector_Db.*already registered.*provider.*baz.*Unregister the existing", + match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.", ): await cached_disk_dist_registry.register(duplicate_vector_db) # Verify the original registration is still intact result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None - assert ( - result.embedding_model == original_vector_db.embedding_model - ) # Original values preserved + assert result.embedding_model == original_vector_db.embedding_model # Original values preserved async def test_get_all_objects(cached_disk_dist_registry): @@ -173,17 +163,12 @@ async def test_get_all_objects(cached_disk_dist_registry): # Verify each vector_db was stored correctly for original_vector_db in test_vector_dbs: - matching_vector_dbs = [ - v for v in all_results if v.identifier == original_vector_db.identifier - ] + matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier] assert len(matching_vector_dbs) == 1 stored_vector_db = matching_vector_dbs[0] assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id - assert ( - stored_vector_db.embedding_dimension - == original_vector_db.embedding_dimension - ) + assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension async def test_parse_registry_values_error_handling(sqlite_kvstore):