forked from phoenix-oss/llama-stack-mirror
unregister for memory banks and remove update API (#458)
The semantics of an Update on resources is very tricky to reason about especially for memory banks and models. The best way to go forward here is for the user to unregister and register a new resource. We don't have a compelling reason to support update APIs. Tests: pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "chroma" --env CHROMA_HOST=localhost --env CHROMA_PORT=8000 pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "pgvector" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 $CONDA_PREFIX/bin/pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_model_registration.py --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
2eab3b7ed9
commit
0850ad656a
18 changed files with 286 additions and 250 deletions
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
@ -43,9 +45,10 @@ def sample_documents():
|
|||
]
|
||||
|
||||
|
||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
||||
async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
|
||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||
return await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank",
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
|
@ -57,43 +60,70 @@ async def register_memory_bank(banks_impl: MemoryBanks):
|
|||
class TestMemory:
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_list(self, memory_stack):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
# Register a test bank
|
||||
registered_bank = await register_memory_bank(banks_impl)
|
||||
|
||||
try:
|
||||
# Verify our bank shows up in list
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert any(
|
||||
bank.memory_bank_id == registered_bank.memory_bank_id
|
||||
for bank in response
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id)
|
||||
|
||||
# Verify our bank was removed
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
assert all(
|
||||
bank.memory_bank_id != registered_bank.memory_bank_id for bank in response
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_register(self, memory_stack):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank_no_provider",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
bank_id = f"test_bank_{uuid.uuid4().hex}"
|
||||
|
||||
# register same memory bank with same id again will fail
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank_no_provider",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
try:
|
||||
# Register initial bank
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
|
||||
# Verify our bank exists
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert any(bank.memory_bank_id == bank_id for bank in response)
|
||||
|
||||
# Try registering same bank again
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
|
||||
# Verify still only one instance of our bank
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert (
|
||||
len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1
|
||||
)
|
||||
finally:
|
||||
# Clean up
|
||||
await banks_impl.unregister_memory_bank(bank_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(self, memory_stack, sample_documents):
|
||||
|
@ -102,17 +132,23 @@ class TestMemory:
|
|||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
||||
await register_memory_bank(banks_impl)
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
registered_bank = await register_memory_bank(banks_impl)
|
||||
await memory_impl.insert_documents(
|
||||
registered_bank.memory_bank_id, sample_documents
|
||||
)
|
||||
|
||||
query1 = "programming language"
|
||||
response1 = await memory_impl.query_documents("test_bank", query1)
|
||||
response1 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query1
|
||||
)
|
||||
assert_valid_response(response1)
|
||||
assert any("Python" in chunk.content for chunk in response1.chunks)
|
||||
|
||||
# Test case 3: Query with semantic similarity
|
||||
query3 = "AI and brain-inspired computing"
|
||||
response3 = await memory_impl.query_documents("test_bank", query3)
|
||||
response3 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query3
|
||||
)
|
||||
assert_valid_response(response3)
|
||||
assert any(
|
||||
"neural networks" in chunk.content.lower() for chunk in response3.chunks
|
||||
|
@ -121,14 +157,18 @@ class TestMemory:
|
|||
# Test case 4: Query with limit on number of results
|
||||
query4 = "computer"
|
||||
params4 = {"max_chunks": 2}
|
||||
response4 = await memory_impl.query_documents("test_bank", query4, params4)
|
||||
response4 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query4, params4
|
||||
)
|
||||
assert_valid_response(response4)
|
||||
assert len(response4.chunks) <= 2
|
||||
|
||||
# Test case 5: Query with threshold on similarity score
|
||||
query5 = "quantum computing" # Not directly related to any document
|
||||
params5 = {"score_threshold": 0.2}
|
||||
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
||||
response5 = await memory_impl.query_documents(
|
||||
registered_bank.memory_bank_id, query5, params5
|
||||
)
|
||||
assert_valid_response(response5)
|
||||
print("The scores are:", response5.scores)
|
||||
assert all(score >= 0.2 for score in response5.scores)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue