From f95bc29ca93be5b46819bcd984792120036030a5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Mar 2025 15:24:07 -0700 Subject: [PATCH] fix: handle registry errors gracefully (#1732) We need to be able to handle stale registry entries gracefully. More needs to be done when we are deleting important attributes from resources which could have been persisted. But at the very least, the server cannot die. ## Test Plan Added unit tests --- llama_stack/distribution/store/registry.py | 18 +++++- tests/unit/registry/test_registry.py | 70 ++++++++++++++++++++++ 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index ef770ff72..76b66cc7a 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -12,9 +12,12 @@ import pydantic from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +logger = get_logger(__name__, category="core") + class DistributionRegistry(Protocol): async def get_all(self) -> List[RoutableObjectWithProvider]: ... @@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider """Utility function to parse registry values into RoutableObjectWithProvider objects.""" all_objects = [] for value in values: - obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) - all_objects.append(obj) + try: + obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value) + all_objects.append(obj) + except pydantic.ValidationError as e: + logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}") + continue + return all_objects @@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry): if not json_str: return None - return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) + try: + return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str) + except pydantic.ValidationError as e: + logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}") + return None async def update(self, obj: RoutableObjectWithProvider) -> None: await self.kvstore.set( diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 1ddba7472..9896b3212 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -12,6 +12,7 @@ import pytest_asyncio from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB from llama_stack.distribution.store.registry import ( + KEY_FORMAT, CachedDiskDistributionRegistry, DiskDistributionRegistry, ) @@ -197,3 +198,72 @@ async def test_get_all_objects(config): assert stored_vector_db.embedding_model == original_vector_db.embedding_model assert stored_vector_db.provider_id == original_vector_db.provider_id assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension + + +@pytest.mark.asyncio +async def test_parse_registry_values_error_handling(config): + kvstore = await kvstore_impl(config) + + valid_db = VectorDB( + identifier="valid_vector_db", + embedding_model="all-MiniLM-L6-v2", + embedding_dimension=384, + provider_resource_id="valid_vector_db", + provider_id="test-provider", + ) + + await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()) + + await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json") + + await kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), + '{"type": "vector_db", "identifier": "missing_fields"}', + ) + + test_registry = DiskDistributionRegistry(kvstore) + await test_registry.initialize() + + # Get all objects, which should only return the valid one + all_objects = await test_registry.get_all() + + # Should have filtered out the invalid entries + assert len(all_objects) == 1 + assert all_objects[0].identifier == "valid_vector_db" + + # Check that the get method also handles errors correctly + invalid_obj = await test_registry.get("vector_db", "corrupted_json") + assert invalid_obj is None + + invalid_obj = await test_registry.get("vector_db", "missing_fields") + assert invalid_obj is None + + +@pytest.mark.asyncio +async def test_cached_registry_error_handling(config): + kvstore = await kvstore_impl(config) + + valid_db = VectorDB( + identifier="valid_cached_db", + embedding_model="all-MiniLM-L6-v2", + embedding_dimension=384, + provider_resource_id="valid_cached_db", + provider_id="test-provider", + ) + + await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()) + + await kvstore.set( + KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"), + '{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string + ) + + cached_registry = CachedDiskDistributionRegistry(kvstore) + await cached_registry.initialize() + + all_objects = await cached_registry.get_all() + assert len(all_objects) == 1 + assert all_objects[0].identifier == "valid_cached_db" + + invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db") + assert invalid_obj is None