diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 11b0dcb45..9b27ed94a 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -149,4 +149,5 @@ pytest_plugins = [ "llama_stack.providers.tests.safety.fixtures", "llama_stack.providers.tests.memory.fixtures", "llama_stack.providers.tests.agents.fixtures", + "llama_stack.providers.tests.datasetio.fixtures", ] diff --git a/llama_stack/providers/tests/datasetio/conftest.py b/llama_stack/providers/tests/datasetio/conftest.py new file mode 100644 index 000000000..740eddb33 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/conftest.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from .fixtures import DATASETIO_FIXTURES + + +def pytest_configure(config): + for fixture_name in DATASETIO_FIXTURES: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_generate_tests(metafunc): + if "datasetio_stack" in metafunc.fixturenames: + metafunc.parametrize( + "datasetio_stack", + [ + pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) + for fixture_name in DATASETIO_FIXTURES + ], + indirect=True, + ) diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py new file mode 100644 index 000000000..3a0ee66e6 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import tempfile + +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture, remote_stack_fixture +from ..env import get_env_or_fail + + +@pytest.fixture(scope="session") +def datasetio_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def datasetio_meta_reference() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config={}, + ) + ], + ) + + +DATASETIO_FIXTURES = ["meta_reference", "remote"] + + +@pytest_asyncio.fixture(scope="session") +async def datasetio_stack(request): + fixture_name = request.param + fixture = request.getfixturevalue(f"datasetio_{fixture_name}") + + impls = await resolve_impls_for_test_v2( + [Api.datasetio], + {"datasetio": fixture.providers}, + fixture.provider_data, + ) + + return impls[Api.datasetio], impls[Api.datasets] diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 4cc46e619..bdaf4bcbf 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -4,10 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest +import os -from llama_stack.apis.memory import * # noqa: F403 +import pytest +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +import base64 +import mimetypes +from pathlib import Path # How to run this test: # @@ -16,72 +21,88 @@ from llama_stack.distribution.datatypes import * # noqa: F403 # -v -s --tb=short --disable-warnings -@pytest.fixture -def sample_documents(): - return [ - MemoryBankDocument( - document_id="doc1", - content="Python is a high-level programming language.", - metadata={"category": "programming", "difficulty": "beginner"}, - ), - MemoryBankDocument( - document_id="doc2", - content="Machine learning is a subset of artificial intelligence.", - metadata={"category": "AI", "difficulty": "advanced"}, - ), - MemoryBankDocument( - document_id="doc3", - content="Data structures are fundamental to computer science.", - metadata={"category": "computer science", "difficulty": "intermediate"}, - ), - MemoryBankDocument( - document_id="doc4", - content="Neural networks are inspired by biological neural networks.", - metadata={"category": "AI", "difficulty": "advanced"}, - ), - ] +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url -async def register_memory_bank(banks_impl: MemoryBanks): - bank = VectorMemoryBankDef( - identifier="test_bank", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, +async def register_dataset( + datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset" +): + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" + test_url = data_url_from_file(str(test_file)) + + if for_generation: + dataset_schema = { + "expected_answer": StringType(), + "input_query": StringType(), + "chat_completion_input": ChatCompletionInputType(), + } + else: + dataset_schema = { + "expected_answer": StringType(), + "input_query": StringType(), + "generated_answer": StringType(), + } + + dataset = DatasetDefWithProvider( + identifier=dataset_id, + provider_id="", + url=URL( + uri=test_url, + ), + dataset_schema=dataset_schema, ) - - await banks_impl.register_memory_bank(bank) + await datasets_impl.register_dataset(dataset) -class TestMemory: +class TestDatasetIO: @pytest.mark.asyncio - async def test_banks_list(self, memory_stack): + async def test_datasets_list(self, datasetio_stack): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful - _, banks_impl = memory_stack - response = await banks_impl.list_memory_banks() + _, datasets_impl = datasetio_stack + response = await datasets_impl.list_datasets() assert isinstance(response, list) assert len(response) == 0 @pytest.mark.asyncio - async def test_banks_register(self, memory_stack): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - _, banks_impl = memory_stack - bank = VectorMemoryBankDef( - identifier="test_bank_no_provider", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, + async def test_register_dataset(self, datasetio_stack): + _, datasets_impl = datasetio_stack + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + assert response[0].identifier == "test_dataset" + + @pytest.mark.asyncio + async def test_get_rows_paginated(self, datasetio_stack): + datasetio_impl, datasets_impl = datasetio_stack + await register_dataset(datasets_impl) + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, ) + assert isinstance(response.rows, list) + assert len(response.rows) == 3 + assert response.next_page_token == "3" - await banks_impl.register_memory_bank(bank) - response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 1 - - # register same memory bank with same id again will fail - await banks_impl.register_memory_bank(bank) - response = await banks_impl.list_memory_banks() - assert isinstance(response, list) - assert len(response) == 1 + # iterate over all rows + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=2, + page_token=response.next_page_token, + ) + assert isinstance(response.rows, list) + assert len(response.rows) == 2 + assert response.next_page_token == "5"