diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 35ccd5178..28013c5e3 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -33,7 +33,7 @@ class DiskRegistry(Registry): return [ pydantic.parse_obj_as( RoutableObjectWithProvider, - obj_str, + json.loads(obj_str), ) for obj_str in objects_data ] diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index cd9b07a79..8ddb61945 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -1,3 +1,5 @@ +import os + import pytest import pytest_asyncio from llama_stack.distribution.store import * @@ -9,7 +11,11 @@ from llama_stack.distribution.datatypes import * # noqa: F403 @pytest.mark.asyncio async def test_registry(): - registry = DiskRegistry(await kvstore_impl(SqliteKVStoreConfig())) + config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db") + # delete the file if it exists + if os.path.exists(config.db_path): + os.remove(config.db_path) + registry = DiskRegistry(await kvstore_impl(config)) bank = VectorMemoryBankDef( identifier="test_bank", embedding_model="all-MiniLM-L6-v2", @@ -17,12 +23,26 @@ async def test_registry(): overlap_size_in_tokens=64, provider_id="bar", ) + model = ModelDefWithProvider( + identifier="test_model", + llama_model="Llama3.2-3B-Instruct", + provider_id="foo", + ) await registry.register(bank) - result_bank = await registry.get("test_bank") - # assert result_bank == bank + await registry.register(model) + results = await registry.get("test_bank") + assert len(results) == 1 + result_bank = results[0] assert result_bank.identifier == bank.identifier assert result_bank.embedding_model == bank.embedding_model assert result_bank.chunk_size_in_tokens == bank.chunk_size_in_tokens assert result_bank.overlap_size_in_tokens == bank.overlap_size_in_tokens assert result_bank.provider_id == bank.provider_id + + results = await registry.get("test_model") + assert len(results) == 1 + result_model = results[0] + assert result_model.identifier == model.identifier + assert result_model.llama_model == model.llama_model + assert result_model.provider_id == model.provider_id