From 5943ce404be12bc7448ca148b3ce2cb92344fcd2 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Tue, 30 Sep 2025 17:03:46 -0700 Subject: [PATCH] updated the error message --- llama_stack/core/store/registry.py | 49 ++++++++++++++++++++-------- tests/unit/registry/test_registry.py | 46 +++++++++++++++++++------- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 5491ab0f1..9f4d774e8 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -14,7 +14,10 @@ from llama_stack.core.datatypes import RoutableObjectWithProvider from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) logger = get_logger(__name__, category="core::registry") @@ -28,7 +31,9 @@ class DistributionRegistry(Protocol): def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ... - async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ... + async def update( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: ... async def register(self, obj: RoutableObjectWithProvider) -> bool: ... @@ -54,7 +59,9 @@ def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) all_objects.append(obj) except pydantic.ValidationError as e: - logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}") + logger.error( + f"Error parsing registry value, raw value: {value}. Error: {e}" + ) continue return all_objects @@ -67,7 +74,9 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + def get_cached( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: # Disk registry does not have a cache raise NotImplementedError("Disk registry does not have a cache") @@ -76,15 +85,23 @@ class DiskDistributionRegistry(DistributionRegistry): values = await self.kvstore.values_in_range(start_key, end_key) return _parse_registry_values(values) - async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: - json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier)) + async def get( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: + json_str = await self.kvstore.get( + KEY_FORMAT.format(type=type, identifier=identifier) + ) if not json_str: return None try: - return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) + return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json( + json_str + ) except pydantic.ValidationError as e: - logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}") + logger.error( + f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}" + ) return None async def update(self, obj: RoutableObjectWithProvider) -> None: @@ -99,8 +116,8 @@ class DiskDistributionRegistry(DistributionRegistry): # dont register if the object's providerid already exists if existing_obj and existing_obj.provider_id == obj.provider_id: raise ValueError( - f"Model '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " - f"Please unregister the existing model first before registering a new one." + f"{obj.type.title()} '{obj.identifier}' is already registered with provider '{obj.provider_id}'. " + f"Please unregister the existing {obj.type} first before registering a new one." ) await self.kvstore.set( @@ -150,7 +167,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def initialize(self) -> None: await self._ensure_initialized() - def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + def get_cached( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: return self.cache.get((type, identifier), None) async def get_all(self) -> list[RoutableObjectWithProvider]: @@ -158,7 +177,9 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async with self._locked_cache() as cache: return list(cache.values()) - async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + async def get( + self, type: str, identifier: str + ) -> RoutableObjectWithProvider | None: await self._ensure_initialized() cache_key = (type, identifier) @@ -200,7 +221,9 @@ async def create_dist_registry( dist_kvstore = await kvstore_impl(metadata_store) else: dist_kvstore = await kvstore_impl( - SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()) + SqliteKVStoreConfig( + db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix() + ) ) dist_registry = CachedDiskDistributionRegistry(dist_kvstore) await dist_registry.initialize() diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 5d8add4bf..83dcf0968 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -10,9 +10,9 @@ import pytest from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.core.store.registry import ( - KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, + KEY_FORMAT, ) from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -62,7 +62,9 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m assert result_model.provider_id == sample_model.provider_id -async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model): +async def test_cached_registry_initialization( + sqlite_kvstore, sample_vector_db, sample_model +): # First populate the disk registry disk_registry = DiskDistributionRegistry(sqlite_kvstore) await disk_registry.initialize() @@ -71,7 +73,9 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, # Test cached version loads from disk db_path = sqlite_kvstore.db_path - cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) + cached_registry = CachedDiskDistributionRegistry( + await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + ) await cached_registry.initialize() result_vector_db = await cached_registry.get("vector_db", "test_vector_db") @@ -93,14 +97,18 @@ async def test_cached_registry_updates(cached_disk_dist_registry): await cached_disk_dist_registry.register(new_vector_db) # Verify in cache - result_vector_db = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") + result_vector_db = await cached_disk_dist_registry.get( + "vector_db", "test_vector_db_2" + ) assert result_vector_db is not None assert result_vector_db.identifier == new_vector_db.identifier assert result_vector_db.provider_id == new_vector_db.provider_id # Verify persisted to disk db_path = cached_disk_dist_registry.kvstore.db_path - new_registry = DiskDistributionRegistry(await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))) + new_registry = DiskDistributionRegistry( + await kvstore_impl(SqliteKVStoreConfig(db_path=db_path)) + ) await new_registry.initialize() result_vector_db = await new_registry.get("vector_db", "test_vector_db_2") assert result_vector_db is not None @@ -127,13 +135,17 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry): ) # Now we expect a ValueError to be raised for duplicate registration - with pytest.raises(ValueError, match=r".*already registered.*"): + with pytest.raises( + ValueError, match=r"Vector_db.*already registered.*provider.*baz.*" + ): await cached_disk_dist_registry.register(duplicate_vector_db) # Verify the original registration is still intact result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") assert result is not None - assert result.embedding_model == original_vector_db.embedding_model # Original values preserved + assert ( + result.embedding_model == original_vector_db.embedding_model + ) # Original values preserved async def test_get_all_objects(cached_disk_dist_registry): @@ -160,12 +172,17 @@ async def test_get_all_objects(cached_disk_dist_registry): # Verify each vector_db was stored correctly for original_vector_db in test_vector_dbs: - matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier] + matching_vector_dbs = [ + v for v in all_results if v.identifier == original_vector_db.identifier + ] assert len(matching_vector_dbs) == 1 stored_vector_db = matching_vector_dbs[0] assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id - assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension + assert ( + stored_vector_db.embedding_dimension + == original_vector_db.embedding_dimension + ) async def test_parse_registry_values_error_handling(sqlite_kvstore): @@ -178,10 +195,14 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), + valid_db.model_dump_json(), ) - await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + await sqlite_kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), + "{not valid json", + ) await sqlite_kvstore.set( KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), @@ -216,7 +237,8 @@ async def test_cached_registry_error_handling(sqlite_kvstore): ) await sqlite_kvstore.set( - KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json() + KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), + valid_db.model_dump_json(), ) await sqlite_kvstore.set(