fix: handle registry errors gracefully

This commit is contained in:
Ashwin Bharambe 2025-03-20 15:06:39 -07:00
parent 86f617a197
commit 0965fcb899
2 changed files with 85 additions and 3 deletions

View file

@ -12,9 +12,12 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(

View file

@ -12,6 +12,7 @@ import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
@ -197,3 +198,72 @@ async def test_get_all_objects(config):
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(config):
kvstore = await kvstore_impl(config)
valid_db = VectorDB(
identifier="valid_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="valid_vector_db",
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}',
)
test_registry = DiskDistributionRegistry(kvstore)
await test_registry.initialize()
# Get all objects, which should only return the valid one
all_objects = await test_registry.get_all()
# Should have filtered out the invalid entries
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_vector_db"
# Check that the get method also handles errors correctly
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
assert invalid_obj is None
invalid_obj = await test_registry.get("vector_db", "missing_fields")
assert invalid_obj is None
@pytest.mark.asyncio
async def test_cached_registry_error_handling(config):
kvstore = await kvstore_impl(config)
valid_db = VectorDB(
identifier="valid_cached_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="valid_cached_db",
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
await kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)
cached_registry = CachedDiskDistributionRegistry(kvstore)
await cached_registry.initialize()
all_objects = await cached_registry.get_all()
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_cached_db"
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
assert invalid_obj is None