workign tests

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:02:45 -07:00 committed by Dinesh Yeduguru
parent 7696f31284
commit 4b6367838f
2 changed files with 24 additions and 4 deletions

View file

@ -33,7 +33,7 @@ class DiskRegistry(Registry):
return [ return [
pydantic.parse_obj_as( pydantic.parse_obj_as(
RoutableObjectWithProvider, RoutableObjectWithProvider,
obj_str, json.loads(obj_str),
) )
for obj_str in objects_data for obj_str in objects_data
] ]

View file

@ -1,3 +1,5 @@
import os
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.distribution.store import * from llama_stack.distribution.store import *
@ -9,7 +11,11 @@ from llama_stack.distribution.datatypes import * # noqa: F403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_registry(): async def test_registry():
registry = DiskRegistry(await kvstore_impl(SqliteKVStoreConfig())) config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
# delete the file if it exists
if os.path.exists(config.db_path):
os.remove(config.db_path)
registry = DiskRegistry(await kvstore_impl(config))
bank = VectorMemoryBankDef( bank = VectorMemoryBankDef(
identifier="test_bank", identifier="test_bank",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
@ -17,12 +23,26 @@ async def test_registry():
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
provider_id="bar", provider_id="bar",
) )
model = ModelDefWithProvider(
identifier="test_model",
llama_model="Llama3.2-3B-Instruct",
provider_id="foo",
)
await registry.register(bank) await registry.register(bank)
result_bank = await registry.get("test_bank") await registry.register(model)
# assert result_bank == bank results = await registry.get("test_bank")
assert len(results) == 1
result_bank = results[0]
assert result_bank.identifier == bank.identifier assert result_bank.identifier == bank.identifier
assert result_bank.embedding_model == bank.embedding_model assert result_bank.embedding_model == bank.embedding_model
assert result_bank.chunk_size_in_tokens == bank.chunk_size_in_tokens assert result_bank.chunk_size_in_tokens == bank.chunk_size_in_tokens
assert result_bank.overlap_size_in_tokens == bank.overlap_size_in_tokens assert result_bank.overlap_size_in_tokens == bank.overlap_size_in_tokens
assert result_bank.provider_id == bank.provider_id assert result_bank.provider_id == bank.provider_id
results = await registry.get("test_model")
assert len(results) == 1
result_model = results[0]
assert result_model.identifier == model.identifier
assert result_model.llama_model == model.llama_model
assert result_model.provider_id == model.provider_id