fix: handle registry errors gracefully

This commit is contained in:
Ashwin Bharambe 2025-03-20 15:06:39 -07:00
parent 86f617a197
commit 0965fcb899
2 changed files with 85 additions and 3 deletions

View file

@ -12,6 +12,7 @@ import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
)
@ -197,3 +198,72 @@ async def test_get_all_objects(config):
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(config):
kvstore = await kvstore_impl(config)
valid_db = VectorDB(
identifier="valid_vector_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="valid_vector_db",
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
'{"type": "vector_db", "identifier": "missing_fields"}',
)
test_registry = DiskDistributionRegistry(kvstore)
await test_registry.initialize()
# Get all objects, which should only return the valid one
all_objects = await test_registry.get_all()
# Should have filtered out the invalid entries
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_vector_db"
# Check that the get method also handles errors correctly
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
assert invalid_obj is None
invalid_obj = await test_registry.get("vector_db", "missing_fields")
assert invalid_obj is None
@pytest.mark.asyncio
async def test_cached_registry_error_handling(config):
kvstore = await kvstore_impl(config)
valid_db = VectorDB(
identifier="valid_cached_db",
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_resource_id="valid_cached_db",
provider_id="test-provider",
)
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
await kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
)
cached_registry = CachedDiskDistributionRegistry(kvstore)
await cached_registry.initialize()
all_objects = await cached_registry.get_all()
assert len(all_objects) == 1
assert all_objects[0].identifier == "valid_cached_db"
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
assert invalid_obj is None