diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index ef770ff72..76b66cc7a 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -12,9 +12,12 @@ import pydantic from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +logger = get_logger(__name__, category="core") + class DistributionRegistry(Protocol): async def get_all(self) -> List[RoutableObjectWithProvider]: ... @@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider """Utility function to parse registry values into RoutableObjectWithProvider objects.""" all_objects = [] for value in values: - obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) - all_objects.append(obj) + try: + obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) + all_objects.append(obj) + except pydantic.ValidationError as e: + logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}") + continue + return all_objects @@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry): if not json_str: return None - return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) + try: + return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) + except pydantic.ValidationError as e: + logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}") + return None async def update(self, obj: RoutableObjectWithProvider) -> None: await self.kvstore.set( diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 1ddba7472..9896b3212 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -12,6 +12,7 @@ import pytest_asyncio from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.distribution.store.registry import ( + KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, ) @@ -197,3 +198,72 @@ async def test_get_all_objects(config): assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension + + +@pytest.mark.asyncio +async def test_parse_registry_values_error_handling(config): + kvstore = await kvstore_impl(config) + + valid_db = VectorDB( + identifier="valid_vector_db", + embedding_model="all-MiniLM-L6-v2", + embedding_dimension=384, + provider_resource_id="valid_vector_db", + provider_id="test-provider", + ) + + await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()) + + await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + + await kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), + '{"type": "vector_db", "identifier": "missing_fields"}', + ) + + test_registry = DiskDistributionRegistry(kvstore) + await test_registry.initialize() + + # Get all objects, which should only return the valid one + all_objects = await test_registry.get_all() + + # Should have filtered out the invalid entries + assert len(all_objects) == 1 + assert all_objects[0].identifier == "valid_vector_db" + + # Check that the get method also handles errors correctly + invalid_obj = await test_registry.get("vector_db", "corrupted_json") + assert invalid_obj is None + + invalid_obj = await test_registry.get("vector_db", "missing_fields") + assert invalid_obj is None + + +@pytest.mark.asyncio +async def test_cached_registry_error_handling(config): + kvstore = await kvstore_impl(config) + + valid_db = VectorDB( + identifier="valid_cached_db", + embedding_model="all-MiniLM-L6-v2", + embedding_dimension=384, + provider_resource_id="valid_cached_db", + provider_id="test-provider", + ) + + await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()) + + await kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"), + '{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string + ) + + cached_registry = CachedDiskDistributionRegistry(kvstore) + await cached_registry.initialize() + + all_objects = await cached_registry.get_all() + assert len(all_objects) == 1 + assert all_objects[0].identifier == "valid_cached_db" + + invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db") + assert invalid_obj is None