Another round of simplification and clarity for models/shields/memory_banks stuff

This commit is contained in:
Ashwin Bharambe 2024-10-09 19:19:26 -07:00
parent 73a0a34e39
commit b55034c0de
27 changed files with 454 additions and 444 deletions

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
@ -30,12 +31,14 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
@pytest_asyncio.fixture(scope="session")
async def memory_impl():
async def memory_settings():
impls = await resolve_impls_for_test(
Api.memory,
memory_banks=[],
)
return impls[Api.memory]
return {
"memory_impl": impls[Api.memory],
"memory_banks_impl": impls[Api.memory_banks],
}
@pytest.fixture
@ -64,23 +67,35 @@ def sample_documents():
]
async def register_memory_bank(memory_impl: Memory):
async def register_memory_bank(banks_impl: MemoryBanks):
bank = VectorMemoryBankDef(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
provider_id=os.environ["PROVIDER_ID"],
)
await memory_impl.register_memory_bank(bank)
await banks_impl.register_memory_bank(bank)
@pytest.mark.asyncio
async def test_query_documents(memory_impl, sample_documents):
async def test_banks_list(memory_settings):
banks_impl = memory_settings["memory_banks_impl"]
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 0
@pytest.mark.asyncio
async def test_query_documents(memory_settings, sample_documents):
memory_impl = memory_settings["memory_impl"]
banks_impl = memory_settings["memory_banks_impl"]
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
await register_memory_bank(memory_impl)
await register_memory_bank(banks_impl)
await memory_impl.insert_documents("test_bank", sample_documents)
query1 = "programming language"