diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index 26fb97072..8b5078151 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -10,7 +10,7 @@ import pandas
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import LocalFSDatasetIOConfig
@@ -40,11 +40,13 @@ class PandasDataframeDataset:
             return
 
         if self.dataset_def.source.type == "uri":
-            self.df = get_dataframe_from_url(self.dataset_def.uri)
+            self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
         elif self.dataset_def.source.type == "rows":
            self.df = pandas.DataFrame(self.dataset_def.source.rows)
         else:
-            raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
+            raise ValueError(
+                f"Unsupported dataset source type: {self.dataset_def.source.type}"
+            )
 
         if self.df is None:
             raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
@@ -117,4 +119,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         dataset_impl.load()
 
         new_rows_df = pandas.DataFrame(rows)
-        dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
+        dataset_impl.df = pandas.concat(
+            [dataset_impl.df, new_rows_df], ignore_index=True
+        )
diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py
index f54cb55eb..4ca0309a5 100644
--- a/llama_stack/providers/utils/datasetio/url_utils.py
+++ b/llama_stack/providers/utils/datasetio/url_utils.py
@@ -14,14 +14,14 @@ from llama_stack.apis.common.content_types import URL
 from llama_stack.providers.utils.memory.vector_store import parse_data_url
 
 
-def get_dataframe_from_url(url: URL):
+def get_dataframe_from_uri(uri: str):
     df = None
-    if url.uri.endswith(".csv"):
-        df = pandas.read_csv(url.uri)
-    elif url.uri.endswith(".xlsx"):
-        df = pandas.read_excel(url.uri)
-    elif url.uri.startswith("data:"):
-        parts = parse_data_url(url.uri)
+    if uri.endswith(".csv"):
+        df = pandas.read_csv(uri)
+    elif uri.endswith(".xlsx"):
+        df = pandas.read_excel(uri)
+    elif uri.startswith("data:"):
+        parts = parse_data_url(uri)
         data = parts["data"]
         if parts["is_base64"]:
             data = base64.b64decode(data)
@@ -39,6 +39,6 @@
         else:
             df = pandas.read_excel(data_bytes)
     else:
-        raise ValueError(f"Unsupported file type: {url}")
+        raise ValueError(f"Unsupported file type: {uri}")
 
     return df
diff --git a/tests/integration/datasetio/test_datasetio.py b/tests/integration/datasetio/test_datasetio.py
deleted file mode 100644
index 459589e7b..000000000
--- a/tests/integration/datasetio/test_datasetio.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import base64
-import mimetypes
-import os
-from pathlib import Path
-
-import pytest
-
-# How to run this test:
-#
-# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio
-
-
-@pytest.fixture
-def dataset_for_test(llama_stack_client):
-    dataset_id = "test_dataset"
-    register_dataset(llama_stack_client, dataset_id=dataset_id)
-    yield
-    # Teardown - this always runs, even if the test fails
-    try:
-        llama_stack_client.datasets.unregister(dataset_id)
-    except Exception as e:
-        print(f"Warning: Failed to unregister test_dataset: {e}")
-
-
-def data_url_from_file(file_path: str) -> str:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-
-    with open(file_path, "rb") as file:
-        file_content = file.read()
-
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
-
-    data_url = f"data:{mime_type};base64,{base64_content}"
-
-    return data_url
-
-
-def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
-    if for_rag:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
-    else:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
-    test_url = data_url_from_file(str(test_file))
-
-    if for_generation:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "chat_completion_input": {"type": "chat_completion_input"},
-        }
-    elif for_rag:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-            "context": {"type": "string"},
-        }
-    else:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-        }
-
-    dataset_providers = [x for x in llama_stack_client.providers.list() if x.api == "datasetio"]
-    dataset_provider_id = dataset_providers[0].provider_id
-
-    llama_stack_client.datasets.register(
-        dataset_id=dataset_id,
-        dataset_schema=dataset_schema,
-        url=dict(uri=test_url),
-        provider_id=dataset_provider_id,
-    )
-
-
-def test_register_unregister_dataset(llama_stack_client):
-    register_dataset(llama_stack_client)
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 1
-    assert response[0].identifier == "test_dataset"
-
-    llama_stack_client.datasets.unregister("test_dataset")
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 0
-
-
-def test_get_rows_paginated(llama_stack_client, dataset_for_test):
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=3,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 3
-    assert response.next_page_token == "3"
-
-    # iterate over all rows
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=2,
-        page_token=response.next_page_token,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 2
-    assert response.next_page_token == "5"
diff --git a/tests/integration/datasetio/__init__.py b/tests/integration/datasets/__init__.py
similarity index 100%
rename from tests/integration/datasetio/__init__.py
rename to tests/integration/datasets/__init__.py
diff --git a/tests/integration/datasetio/test_dataset.csv b/tests/integration/datasets/test_dataset.csv
similarity index 100%
rename from tests/integration/datasetio/test_dataset.csv
rename to tests/integration/datasets/test_dataset.csv
diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py
index dea0ac2aa..b24af5984 100644
--- a/tests/integration/datasets/test_datasets.py
+++ b/tests/integration/datasets/test_datasets.py
@@ -5,6 +5,10 @@
 # the root directory of this source tree.
 
 
+import base64
+import mimetypes
+import os
+
 import pytest
 
 # How to run this test:
@@ -12,6 +16,21 @@ import pytest
 # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets
 
 
+def data_url_from_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
 @pytest.mark.parametrize(
     "purpose, source, provider_id, limit",
     [
@@ -67,3 +86,19 @@ def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id,
     llama_stack_client.datasets.unregister(dataset.identifier)
     dataset_list = llama_stack_client.datasets.list()
     assert dataset.identifier not in [d.identifier for d in dataset_list]
+
+
+def test_register_and_iterrows_from_base64_data(llama_stack_client):
+    dataset = llama_stack_client.datasets.register(
+        purpose="eval/messages-answer",
+        source={
+            "type": "uri",
+            "uri": data_url_from_file(
+                os.path.join(os.path.dirname(__file__), "test_dataset.csv")
+            ),
+        },
+    )
+    assert dataset.identifier is not None
+    assert dataset.provider_id == "localfs"
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier)
+    assert len(iterrow_response.data) == 5
diff --git a/tests/integration/datasetio/test_rag_dataset.csv b/tests/integration/datasets/test_rag_dataset.csv
similarity index 100%
rename from tests/integration/datasetio/test_rag_dataset.csv
rename to tests/integration/datasets/test_rag_dataset.csv
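
For context, a minimal sketch (not part of the diff) of how the renamed helper could be exercised directly after this change. It assumes a llama-stack checkout on the import path; the inline CSV payload is invented for illustration:

    # Sketch: load a DataFrame through get_dataframe_from_uri using a data: URI.
    import base64

    from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri

    # Hypothetical one-row CSV payload, base64-encoded into a data: URI.
    # The "data:" prefix routes into the parse_data_url branch, which
    # base64-decodes the payload; a text/* mimetype then selects
    # pandas.read_csv on the decoded bytes.
    csv_bytes = b"input_query,expected_answer\nWhat is 2+2?,4\n"
    uri = "data:text/csv;base64," + base64.b64encode(csv_bytes).decode("utf-8")

    df = get_dataframe_from_uri(uri)
    print(df.shape)  # (1, 2): one data row, two columns

    # Plain paths still work: a ".csv" suffix dispatches to pandas.read_csv,
    # and ".xlsx" to pandas.read_excel; anything else raises ValueError.

This is the same path that test_register_and_iterrows_from_base64_data drives through the datasets API, minus the client and registration plumbing.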