mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
fix: handle registry errors gracefully
This commit is contained in:
parent
86f617a197
commit
0965fcb899
2 changed files with 85 additions and 3 deletions
|
@ -12,9 +12,12 @@ import pydantic
|
|||
|
||||
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
|
||||
|
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
|
|||
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
|
||||
all_objects = []
|
||||
for value in values:
|
||||
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
|
||||
all_objects.append(obj)
|
||||
try:
|
||||
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
|
||||
all_objects.append(obj)
|
||||
except pydantic.ValidationError as e:
|
||||
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
|
||||
continue
|
||||
|
||||
return all_objects
|
||||
|
||||
|
||||
|
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
if not json_str:
|
||||
return None
|
||||
|
||||
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
|
||||
try:
|
||||
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
|
||||
except pydantic.ValidationError as e:
|
||||
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
|
||||
return None
|
||||
|
||||
async def update(self, obj: RoutableObjectWithProvider) -> None:
|
||||
await self.kvstore.set(
|
||||
|
|
|
@ -12,6 +12,7 @@ import pytest_asyncio
|
|||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.store.registry import (
|
||||
KEY_FORMAT,
|
||||
CachedDiskDistributionRegistry,
|
||||
DiskDistributionRegistry,
|
||||
)
|
||||
|
@ -197,3 +198,72 @@ async def test_get_all_objects(config):
|
|||
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
|
||||
assert stored_vector_db.provider_id == original_vector_db.provider_id
|
||||
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(config):
    """Check that unparseable entries do not break a disk-backed registry.

    The store is seeded with one well-formed vector DB record, one entry that
    is not valid JSON, and one JSON entry lacking fields. ``get_all`` is
    expected to return only the well-formed record, and ``get`` is expected to
    return ``None`` for each of the two bad entries.
    """
    store = await kvstore_impl(config)

    good_entry = VectorDB(
        identifier="valid_vector_db",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="valid_vector_db",
        provider_id="test-provider",
    )

    # Insertion order mirrors the order in which the entries are written.
    raw_entries = {
        "valid_vector_db": good_entry.model_dump_json(),
        "corrupted_json": "{not valid json",
        "missing_fields": '{"type": "vector_db", "identifier": "missing_fields"}',
    }
    for entry_id, payload in raw_entries.items():
        await store.set(KEY_FORMAT.format(type="vector_db", identifier=entry_id), payload)

    registry = DiskDistributionRegistry(store)
    await registry.initialize()

    # Listing everything must drop the two invalid entries.
    listed = await registry.get_all()
    assert len(listed) == 1
    assert listed[0].identifier == "valid_vector_db"

    # Direct lookups of the invalid entries must come back empty.
    for bad_id in ("corrupted_json", "missing_fields"):
        looked_up = await registry.get("vector_db", bad_id)
        assert looked_up is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cached_registry_error_handling(config):
    """Check that the cached registry skips an entry that fails validation.

    One well-formed vector DB record and one record whose ``embedding_model``
    is a number are written to the store. ``get_all`` is expected to return
    only the well-formed record, and ``get`` on the bad record is expected to
    return ``None``.
    """
    store = await kvstore_impl(config)

    good_entry = VectorDB(
        identifier="valid_cached_db",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_resource_id="valid_cached_db",
        provider_id="test-provider",
    )
    good_key = KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db")
    await store.set(good_key, good_entry.model_dump_json())

    # embedding_model should be a string, so this payload is invalid.
    bad_key = KEY_FORMAT.format(type="vector_db", identifier="invalid_cached_db")
    bad_payload = '{"type": "vector_db", "identifier": "invalid_cached_db", "embedding_model": 12345}'
    await store.set(bad_key, bad_payload)

    registry = CachedDiskDistributionRegistry(store)
    await registry.initialize()

    listed = await registry.get_all()
    assert len(listed) == 1
    assert listed[0].identifier == "valid_cached_db"

    looked_up = await registry.get("vector_db", "invalid_cached_db")
    assert looked_up is None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue