# llama-stack-mirror/llama_stack/distribution/store/tests/test_registry.py
# Dinesh Yeduguru 4b6367838f workign tests
# 2024-11-04 09:37:23 -08:00
#
# 48 lines
# 1.8 KiB
# Python

import os
import tempfile

import pytest
import pytest_asyncio

from llama_stack.distribution.store import *  # noqa: F403
from llama_stack.apis.memory_banks import GraphMemoryBankDef, VectorMemoryBankDef
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.datatypes import *  # noqa: F403
@pytest.mark.asyncio
async def test_registry():
    """Round-trip a memory bank and a model through a SQLite-backed DiskRegistry.

    Registers one ``VectorMemoryBankDef`` and one ``ModelDefWithProvider``,
    then looks each one up by identifier. For each lookup it asserts that
    exactly one entry comes back and that the fields match what was registered.
    """
    # Use a private temporary directory instead of a fixed /tmp path. Each run
    # then starts from an empty database, concurrent or repeated runs cannot
    # see each other's state, the test does not depend on /tmp existing, and
    # the database file is removed when the directory is cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = SqliteKVStoreConfig(
            db_path=os.path.join(tmp_dir, "test_registry.db")
        )
        registry = DiskRegistry(await kvstore_impl(config))

        bank = VectorMemoryBankDef(
            identifier="test_bank",
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
            provider_id="bar",
        )
        model = ModelDefWithProvider(
            identifier="test_model",
            llama_model="Llama3.2-3B-Instruct",
            provider_id="foo",
        )

        await registry.register(bank)
        await registry.register(model)

        # The memory bank must come back as a single entry with every field intact.
        results = await registry.get("test_bank")
        assert len(results) == 1
        result_bank = results[0]
        assert result_bank.identifier == bank.identifier
        assert result_bank.embedding_model == bank.embedding_model
        assert result_bank.chunk_size_in_tokens == bank.chunk_size_in_tokens
        assert result_bank.overlap_size_in_tokens == bank.overlap_size_in_tokens
        assert result_bank.provider_id == bank.provider_id

        # The model must likewise come back as a single, unaltered entry.
        results = await registry.get("test_model")
        assert len(results) == 1
        result_model = results[0]
        assert result_model.identifier == model.identifier
        assert result_model.llama_model == model.llama_model
        assert result_model.provider_id == model.provider_id