This commit is contained in:
Xi Yan 2025-01-15 15:59:57 -08:00
parent 7fdc9e04ac
commit 96edf26a35

View file

@ -15,7 +15,8 @@ from llama_stack_client.types.memory_insert_params import Document
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def empty_memory_bank_registry(llama_stack_client): def empty_memory_bank_registry(llama_stack_client):
memory_banks = [ memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
for memory_bank_id in memory_banks: for memory_bank_id in memory_banks:
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
@ -35,7 +36,8 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi
provider_id="faiss", provider_id="faiss",
) )
memory_banks = [ memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
return memory_banks return memory_banks
@ -104,7 +106,8 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry): def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
memory_banks_after_register = [ memory_banks_after_register = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert len(memory_banks_after_register) == 0 assert len(memory_banks_after_register) == 0
@ -124,14 +127,16 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
) )
memory_banks_after_register = [ memory_banks_after_register = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert memory_banks_after_register == [memory_bank_id] assert memory_banks_after_register == [memory_bank_id]
def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry): def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
memory_banks = [ memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert len(memory_banks) == 1 assert len(memory_banks) == 1
@ -139,7 +144,8 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
memory_banks = [ memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert len(memory_banks) == 0 assert len(memory_banks) == 0
@ -195,11 +201,10 @@ def test_memory_bank_insert_inline_and_query(
def test_memory_bank_insert_from_url_and_query( def test_memory_bank_insert_from_url_and_query(
llama_stack_client, empty_memory_bank_registry llama_stack_client, empty_memory_bank_registry
): ):
providers = llama_stack_client.providers.list() providers = llama_stack_client.providers.list().memory
assert "memory" in providers assert len(providers) > 0
assert len(providers["memory"]) > 0
memory_provider_id = providers["memory"][0].provider_id memory_provider_id = providers[0]["provider_id"]
memory_bank_id = "test_bank" memory_bank_id = "test_bank"
llama_stack_client.memory_banks.register( llama_stack_client.memory_banks.register(
@ -215,7 +220,8 @@ def test_memory_bank_insert_from_url_and_query(
# list to check memory bank is successfully registered # list to check memory bank is successfully registered
available_memory_banks = [ available_memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
] ]
assert memory_bank_id in available_memory_banks assert memory_bank_id in available_memory_banks