From 96edf26a3575fd18985aacd1a5d2e3e431c40710 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 15 Jan 2025 15:59:57 -0800 Subject: [PATCH] memory --- tests/client-sdk/memory/test_memory.py | 28 ++++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/client-sdk/memory/test_memory.py b/tests/client-sdk/memory/test_memory.py index 998c30125..a5f154fda 100644 --- a/tests/client-sdk/memory/test_memory.py +++ b/tests/client-sdk/memory/test_memory.py @@ -15,7 +15,8 @@ from llama_stack_client.types.memory_insert_params import Document @pytest.fixture(scope="function") def empty_memory_bank_registry(llama_stack_client): memory_banks = [ - memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() + memory_bank.identifier + for memory_bank in llama_stack_client.memory_banks.list().data ] for memory_bank_id in memory_banks: llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) @@ -35,7 +36,8 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi provider_id="faiss", ) memory_banks = [ - memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() + memory_bank.identifier + for memory_bank in llama_stack_client.memory_banks.list().data ] return memory_banks @@ -104,7 +106,8 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry): def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry): memory_banks_after_register = [ - memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() + memory_bank.identifier + for memory_bank in llama_stack_client.memory_banks.list().data ] assert len(memory_banks_after_register) == 0 @@ -124,14 +127,16 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry): ) memory_banks_after_register = [ - memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() + memory_bank.identifier + for memory_bank in llama_stack_client.memory_banks.list().data ] assert memory_banks_after_register == [memory_bank_id] def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry): memory_banks = [ - memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() + memory_bank.identifier + for memory_bank in llama_stack_client.memory_banks.list().data ] assert len(memory_banks) == 1 @@ -139,7 +144,8 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id) memory_banks = [ - memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() + memory_bank.identifier + for memory_bank in llama_stack_client.memory_banks.list().data ] assert len(memory_banks) == 0 @@ -195,11 +201,10 @@ def test_memory_bank_insert_inline_and_query( def test_memory_bank_insert_from_url_and_query( llama_stack_client, empty_memory_bank_registry ): - providers = llama_stack_client.providers.list() - assert "memory" in providers - assert len(providers["memory"]) > 0 + providers = llama_stack_client.providers.list().memory + assert len(providers) > 0 - memory_provider_id = providers["memory"][0].provider_id + memory_provider_id = providers[0]["provider_id"] memory_bank_id = "test_bank" llama_stack_client.memory_banks.register( @@ -215,7 +220,8 @@ def test_memory_bank_insert_from_url_and_query( # list to check memory bank is successfully registered available_memory_banks = [ - memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list() + memory_bank.identifier + for memory_bank in llama_stack_client.memory_banks.list().data ] assert memory_bank_id in available_memory_banks