# llama-stack/tests/client-sdk/memory/test_memory.py
# 2024-12-26 15:48:28 -08:00
#
# 71 lines
# 2.1 KiB
# Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client.types.memory_insert_params import Document
def test_memory_bank(llama_stack_client):
    """Exercise the register -> insert -> query flow of a vector memory bank.

    The test is skipped when the client's provider listing has no "memory"
    entry. Otherwise it registers a bank named "test_bank" with the first
    memory provider, inserts four documents whose content is a torchtune
    tutorial URL, and checks that a LoRA-related query returns at least one
    chunk, one score per chunk, and "lora" in the first chunk's content.
    """
    provider_map = llama_stack_client.providers.list()
    if "memory" not in provider_map:
        pytest.skip("No memory provider available")

    # The first listed memory provider hosts the bank.
    memory_providers = provider_map["memory"]
    assert len(memory_providers) > 0
    provider_id = memory_providers[0].provider_id

    bank_id = "test_bank"
    bank_params = {
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
    }
    llama_stack_client.memory_banks.register(
        memory_bank_id=bank_id,
        params=bank_params,
        provider_id=provider_id,
    )

    # Confirm the registration is visible through the listing endpoint.
    registered_ids = []
    for bank in llama_stack_client.memory_banks.list():
        registered_ids.append(bank.identifier)
    assert bank_id in registered_ids

    # Build one document per tutorial page; the content field carries the
    # page URL (presumably resolved server-side -- confirm with the provider).
    tutorial_pages = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
    ]
    docs = []
    for i, url in enumerate(tutorial_pages):
        docs.append(
            Document(
                document_id=f"num-{i}",
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                mime_type="text/plain",
                metadata={},
            )
        )
    llama_stack_client.memory.insert(
        bank_id=bank_id,
        documents=docs,
    )

    # Query the bank and validate the shape and relevance of the result.
    result = llama_stack_client.memory.query(
        bank_id=bank_id,
        query="How do I use lora",
    )
    chunks = result.chunks
    assert len(chunks) > 0
    assert len(chunks) == len(result.scores)

    chunk_texts = [chunk.content for chunk in chunks]
    assert "lora" in chunk_texts[0].lower()