workign tests

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:02:45 -07:00 committed by Dinesh Yeduguru
parent 7696f31284
commit 4b6367838f
2 changed files with 24 additions and 4 deletions

View file

@ -33,7 +33,7 @@ class DiskRegistry(Registry):
return [
pydantic.parse_obj_as(
RoutableObjectWithProvider,
obj_str,
json.loads(obj_str),
)
for obj_str in objects_data
]

View file

@ -1,3 +1,5 @@
import os
import pytest
import pytest_asyncio
from llama_stack.distribution.store import *
@ -9,7 +11,11 @@ from llama_stack.distribution.datatypes import * # noqa: F403
@pytest.mark.asyncio
async def test_registry():
registry = DiskRegistry(await kvstore_impl(SqliteKVStoreConfig()))
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
# delete the file if it exists
if os.path.exists(config.db_path):
os.remove(config.db_path)
registry = DiskRegistry(await kvstore_impl(config))
bank = VectorMemoryBankDef(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
@ -17,12 +23,26 @@ async def test_registry():
overlap_size_in_tokens=64,
provider_id="bar",
)
model = ModelDefWithProvider(
identifier="test_model",
llama_model="Llama3.2-3B-Instruct",
provider_id="foo",
)
await registry.register(bank)
result_bank = await registry.get("test_bank")
# assert result_bank == bank
await registry.register(model)
results = await registry.get("test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == bank.identifier
assert result_bank.embedding_model == bank.embedding_model
assert result_bank.chunk_size_in_tokens == bank.chunk_size_in_tokens
assert result_bank.overlap_size_in_tokens == bank.overlap_size_in_tokens
assert result_bank.provider_id == bank.provider_id
results = await registry.get("test_model")
assert len(results) == 1
result_model = results[0]
assert result_model.identifier == model.identifier
assert result_model.llama_model == model.llama_model
assert result_model.provider_id == model.provider_id