datasetio test

This commit is contained in:
Xi Yan 2024-11-06 16:00:38 -08:00
parent 1b7e19d5d0
commit 1fe4099bd0
4 changed files with 160 additions and 57 deletions

View file

@ -4,10 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import os
from llama_stack.apis.memory import * # noqa: F403
import pytest
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import base64
import mimetypes
from pathlib import Path
# How to run this test:
#
@ -16,72 +21,88 @@ from llama_stack.distribution.datatypes import * # noqa: F403
# -v -s --tb=short --disable-warnings
@pytest.fixture
def sample_documents():
return [
MemoryBankDocument(
document_id="doc1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
MemoryBankDocument(
document_id="doc2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
MemoryBankDocument(
document_id="doc3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
MemoryBankDocument(
document_id="doc4",
content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"},
),
]
def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
async def register_memory_bank(banks_impl: MemoryBanks):
bank = VectorMemoryBankDef(
identifier="test_bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
async def register_dataset(
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
if for_generation:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"chat_completion_input": ChatCompletionInputType(),
}
else:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
}
dataset = DatasetDefWithProvider(
identifier=dataset_id,
provider_id="",
url=URL(
uri=test_url,
),
dataset_schema=dataset_schema,
)
await banks_impl.register_memory_bank(bank)
await datasets_impl.register_dataset(dataset)
class TestMemory:
class TestDatasetIO:
@pytest.mark.asyncio
async def test_banks_list(self, memory_stack):
async def test_datasets_list(self, datasetio_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, banks_impl = memory_stack
response = await banks_impl.list_memory_banks()
_, datasets_impl = datasetio_stack
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
@pytest.mark.asyncio
async def test_banks_register(self, memory_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, banks_impl = memory_stack
bank = VectorMemoryBankDef(
identifier="test_bank_no_provider",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
async def test_register_dataset(self, datasetio_stack):
_, datasets_impl = datasetio_stack
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
assert response[0].identifier == "test_dataset"
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack
await register_dataset(datasets_impl)
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 3
assert response.next_page_token == "3"
await banks_impl.register_memory_bank(bank)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1
# register same memory bank with same id again will fail
await banks_impl.register_memory_bank(bank)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=2,
page_token=response.next_page_token,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 2
assert response.next_page_token == "5"