mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 18:50:44 +00:00
fix: handle registry errors gracefully
This commit is contained in:
parent
86f617a197
commit
0965fcb899
2 changed files with 85 additions and 3 deletions
|
@ -12,9 +12,12 @@ import pydantic
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
class DistributionRegistry(Protocol):
|
class DistributionRegistry(Protocol):
|
||||||
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
|
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
|
||||||
|
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
|
||||||
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
|
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
|
||||||
all_objects = []
|
all_objects = []
|
||||||
for value in values:
|
for value in values:
|
||||||
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
|
try:
|
||||||
all_objects.append(obj)
|
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
|
||||||
|
all_objects.append(obj)
|
||||||
|
except pydantic.ValidationError as e:
|
||||||
|
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
return all_objects
|
return all_objects
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
if not json_str:
|
if not json_str:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
|
try:
|
||||||
|
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
|
||||||
|
except pydantic.ValidationError as e:
|
||||||
|
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def update(self, obj: RoutableObjectWithProvider) -> None:
|
async def update(self, obj: RoutableObjectWithProvider) -> None:
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
|
|
|
@ -12,6 +12,7 @@ import pytest_asyncio
|
||||||
from llama_stack.apis.inference import Model
|
from llama_stack.apis.inference import Model
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.distribution.store.registry import (
|
from llama_stack.distribution.store.registry import (
|
||||||
|
KEY_FORMAT,
|
||||||
CachedDiskDistributionRegistry,
|
CachedDiskDistributionRegistry,
|
||||||
DiskDistributionRegistry,
|
DiskDistributionRegistry,
|
||||||
)
|
)
|
||||||
|
@ -197,3 +198,72 @@ async def test_get_all_objects(config):
|
||||||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_registry_values_error_handling(config):
|
||||||
|
kvstore = await kvstore_impl(config)
|
||||||
|
|
||||||
|
valid_db = VectorDB(
|
||||||
|
identifier="valid_vector_db",
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
embedding_dimension=384,
|
||||||
|
provider_resource_id="valid_vector_db",
|
||||||
|
provider_id="test-provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json())
|
||||||
|
|
||||||
|
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||||
|
|
||||||
|
await kvstore.set(
|
||||||
|
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||||
|
'{"type": "vector_db", "identifier": "missing_fields"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
test_registry = DiskDistributionRegistry(kvstore)
|
||||||
|
await test_registry.initialize()
|
||||||
|
|
||||||
|
# Get all objects, which should only return the valid one
|
||||||
|
all_objects = await test_registry.get_all()
|
||||||
|
|
||||||
|
# Should have filtered out the invalid entries
|
||||||
|
assert len(all_objects) == 1
|
||||||
|
assert all_objects[0].identifier == "valid_vector_db"
|
||||||
|
|
||||||
|
# Check that the get method also handles errors correctly
|
||||||
|
invalid_obj = await test_registry.get("vector_db", "corrupted_json")
|
||||||
|
assert invalid_obj is None
|
||||||
|
|
||||||
|
invalid_obj = await test_registry.get("vector_db", "missing_fields")
|
||||||
|
assert invalid_obj is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cached_registry_error_handling(config):
|
||||||
|
kvstore = await kvstore_impl(config)
|
||||||
|
|
||||||
|
valid_db = VectorDB(
|
||||||
|
identifier="valid_cached_db",
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
embedding_dimension=384,
|
||||||
|
provider_resource_id="valid_cached_db",
|
||||||
|
provider_id="test-provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
await kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json())
|
||||||
|
|
||||||
|
await kvstore.set(
|
||||||
|
KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db"),
|
||||||
|
'{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}', # Should be string
|
||||||
|
)
|
||||||
|
|
||||||
|
cached_registry = CachedDiskDistributionRegistry(kvstore)
|
||||||
|
await cached_registry.initialize()
|
||||||
|
|
||||||
|
all_objects = await cached_registry.get_all()
|
||||||
|
assert len(all_objects) == 1
|
||||||
|
assert all_objects[0].identifier == "valid_cached_db"
|
||||||
|
|
||||||
|
invalid_obj = await cached_registry.get("vector_db", "invalid_cached_db")
|
||||||
|
assert invalid_obj is None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue