fix tests

This commit is contained in:
Dinesh Yeduguru 2024-11-08 17:00:15 -08:00
parent 24b914f0fe
commit 5cdcdbe074
3 changed files with 37 additions and 20 deletions

View file

@ -8,6 +8,7 @@ import pytest
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams
# How to run this test:
#
@ -44,13 +45,14 @@ def sample_documents():
async def register_memory_bank(banks_impl: MemoryBanks):
await banks_impl.register_memory_bank(
VectorRegistration(
memory_bank_id="test_bank",
return await banks_impl.register_memory_bank(
memory_bank_id="test_bank",
memory_bank_type="vector",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
),
)
@ -69,20 +71,30 @@ class TestMemory:
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, banks_impl = memory_stack
bank = VectorRegistration(
memory_bank_id="test_bank_no_provider",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
await banks_impl.register_memory_bank(bank)
bank = await banks_impl.register_memory_bank(
memory_bank_id="test_bank_no_provider",
memory_bank_type="vector",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1
# register same memory bank with same id again will fail
await banks_impl.register_memory_bank(bank)
await banks_impl.register_memory_bank(
memory_bank_id="test_bank_no_provider",
memory_bank_type="vector",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1