mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
datasetio test
This commit is contained in:
parent
1b7e19d5d0
commit
1fe4099bd0
4 changed files with 160 additions and 57 deletions
|
@ -149,4 +149,5 @@ pytest_plugins = [
|
||||||
"llama_stack.providers.tests.safety.fixtures",
|
"llama_stack.providers.tests.safety.fixtures",
|
||||||
"llama_stack.providers.tests.memory.fixtures",
|
"llama_stack.providers.tests.memory.fixtures",
|
||||||
"llama_stack.providers.tests.agents.fixtures",
|
"llama_stack.providers.tests.agents.fixtures",
|
||||||
|
"llama_stack.providers.tests.datasetio.fixtures",
|
||||||
]
|
]
|
||||||
|
|
29
llama_stack/providers/tests/datasetio/conftest.py
Normal file
29
llama_stack/providers/tests/datasetio/conftest.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .fixtures import DATASETIO_FIXTURES
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
for fixture_name in DATASETIO_FIXTURES:
|
||||||
|
config.addinivalue_line(
|
||||||
|
"markers",
|
||||||
|
f"{fixture_name}: marks tests as {fixture_name} specific",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
if "datasetio_stack" in metafunc.fixturenames:
|
||||||
|
metafunc.parametrize(
|
||||||
|
"datasetio_stack",
|
||||||
|
[
|
||||||
|
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||||
|
for fixture_name in DATASETIO_FIXTURES
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
52
llama_stack/providers/tests/datasetio/fixtures.py
Normal file
52
llama_stack/providers/tests/datasetio/fixtures.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
from ..env import get_env_or_fail
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def datasetio_remote() -> ProviderFixture:
|
||||||
|
return remote_stack_fixture()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def datasetio_meta_reference() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
providers=[
|
||||||
|
Provider(
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_type="meta-reference",
|
||||||
|
config={},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
DATASETIO_FIXTURES = ["meta_reference", "remote"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def datasetio_stack(request):
|
||||||
|
fixture_name = request.param
|
||||||
|
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
|
||||||
|
|
||||||
|
impls = await resolve_impls_for_test_v2(
|
||||||
|
[Api.datasetio],
|
||||||
|
{"datasetio": fixture.providers},
|
||||||
|
fixture.provider_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
return impls[Api.datasetio], impls[Api.datasets]
|
|
@ -4,10 +4,15 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import pytest
|
import os
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
import pytest
|
||||||
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
import base64
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
#
|
#
|
||||||
|
@ -16,72 +21,88 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
# -v -s --tb=short --disable-warnings
|
# -v -s --tb=short --disable-warnings
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def data_url_from_file(file_path: str) -> str:
|
||||||
def sample_documents():
|
if not os.path.exists(file_path):
|
||||||
return [
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="doc1",
|
with open(file_path, "rb") as file:
|
||||||
content="Python is a high-level programming language.",
|
file_content = file.read()
|
||||||
metadata={"category": "programming", "difficulty": "beginner"},
|
|
||||||
),
|
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||||
MemoryBankDocument(
|
mime_type, _ = mimetypes.guess_type(file_path)
|
||||||
document_id="doc2",
|
|
||||||
content="Machine learning is a subset of artificial intelligence.",
|
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||||
metadata={"category": "AI", "difficulty": "advanced"},
|
|
||||||
),
|
return data_url
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="doc3",
|
|
||||||
content="Data structures are fundamental to computer science.",
|
|
||||||
metadata={"category": "computer science", "difficulty": "intermediate"},
|
|
||||||
),
|
|
||||||
MemoryBankDocument(
|
|
||||||
document_id="doc4",
|
|
||||||
content="Neural networks are inspired by biological neural networks.",
|
|
||||||
metadata={"category": "AI", "difficulty": "advanced"},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
async def register_dataset(
|
||||||
bank = VectorMemoryBankDef(
|
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
|
||||||
identifier="test_bank",
|
):
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||||
chunk_size_in_tokens=512,
|
test_url = data_url_from_file(str(test_file))
|
||||||
overlap_size_in_tokens=64,
|
|
||||||
|
if for_generation:
|
||||||
|
dataset_schema = {
|
||||||
|
"expected_answer": StringType(),
|
||||||
|
"input_query": StringType(),
|
||||||
|
"chat_completion_input": ChatCompletionInputType(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
dataset_schema = {
|
||||||
|
"expected_answer": StringType(),
|
||||||
|
"input_query": StringType(),
|
||||||
|
"generated_answer": StringType(),
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset = DatasetDefWithProvider(
|
||||||
|
identifier=dataset_id,
|
||||||
|
provider_id="",
|
||||||
|
url=URL(
|
||||||
|
uri=test_url,
|
||||||
|
),
|
||||||
|
dataset_schema=dataset_schema,
|
||||||
)
|
)
|
||||||
|
await datasets_impl.register_dataset(dataset)
|
||||||
await banks_impl.register_memory_bank(bank)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMemory:
|
class TestDatasetIO:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_list(self, memory_stack):
|
async def test_datasets_list(self, datasetio_stack):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
_, banks_impl = memory_stack
|
_, datasets_impl = datasetio_stack
|
||||||
response = await banks_impl.list_memory_banks()
|
response = await datasets_impl.list_datasets()
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
assert len(response) == 0
|
assert len(response) == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_register(self, memory_stack):
|
async def test_register_dataset(self, datasetio_stack):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
_, datasets_impl = datasetio_stack
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
await register_dataset(datasets_impl)
|
||||||
_, banks_impl = memory_stack
|
response = await datasets_impl.list_datasets()
|
||||||
bank = VectorMemoryBankDef(
|
assert isinstance(response, list)
|
||||||
identifier="test_bank_no_provider",
|
assert len(response) == 1
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
assert response[0].identifier == "test_dataset"
|
||||||
chunk_size_in_tokens=512,
|
|
||||||
overlap_size_in_tokens=64,
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_rows_paginated(self, datasetio_stack):
|
||||||
|
datasetio_impl, datasets_impl = datasetio_stack
|
||||||
|
await register_dataset(datasets_impl)
|
||||||
|
response = await datasetio_impl.get_rows_paginated(
|
||||||
|
dataset_id="test_dataset",
|
||||||
|
rows_in_page=3,
|
||||||
)
|
)
|
||||||
|
assert isinstance(response.rows, list)
|
||||||
|
assert len(response.rows) == 3
|
||||||
|
assert response.next_page_token == "3"
|
||||||
|
|
||||||
await banks_impl.register_memory_bank(bank)
|
# iterate over all rows
|
||||||
response = await banks_impl.list_memory_banks()
|
response = await datasetio_impl.get_rows_paginated(
|
||||||
assert isinstance(response, list)
|
dataset_id="test_dataset",
|
||||||
assert len(response) == 1
|
rows_in_page=2,
|
||||||
|
page_token=response.next_page_token,
|
||||||
# register same memory bank with same id again will fail
|
)
|
||||||
await banks_impl.register_memory_bank(bank)
|
assert isinstance(response.rows, list)
|
||||||
response = await banks_impl.list_memory_banks()
|
assert len(response.rows) == 2
|
||||||
assert isinstance(response, list)
|
assert response.next_page_token == "5"
|
||||||
assert len(response) == 1
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue