From a568bf3f9deecbea847c7fdcde362e95aede08d9 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 15 Mar 2025 16:48:09 -0700
Subject: [PATCH] feat(dataset api): (1.5/n) fix dataset registration (#1659)

# What does this PR do?

- fix dataset registration & iterrows

> NOTE: the URL endpoint is changed to datasetio due to flaky path routing

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```

[//]: # (## Documentation)
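For reviewers, a minimal usage sketch of the flow this PR fixes, mirroring the new integration test; a locally running stack and the base URL are assumptions, not part of the patch:

```python
# Sketch only: register a rows-backed dataset and page through it with the new
# datasets.register / datasets.iterrows flow. Assumes a stack at localhost:8321
# and the llama_stack_client package installed.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

dataset = client.datasets.register(
    purpose="eval/messages-answer",
    source={
        "type": "rows",
        "rows": [
            {"messages": [{"role": "user", "content": "Hello, world!"}], "answer": "Hello, world!"},
        ],
    },
)

# iterrows is served by the datasetio provider under the new /v1/datasetio/... routes
page = client.datasets.iterrows(dataset.identifier, limit=1)
print(page.data, page.next_index)
```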
---
 docs/_static/llama-stack-spec.html             |  12 +-
 docs/_static/llama-stack-spec.yaml             |  12 +-
 docs/openapi_generator/pyopenapi/generator.py  |   4 +-
 llama_stack/apis/datasetio/datasetio.py        |   5 +-
 .../distribution/routers/routing_tables.py     |   4 +-
 .../inline/datasetio/localfs/datasetio.py      | 112 ++++------------
 .../datasetio/huggingface/huggingface.py       |  32 ++---
 .../providers/utils/datasetio/url_utils.py     |  17 ++-
 tests/integration/datasetio/test_datasetio.py  | 114 ------------------
 .../{datasetio => datasets}/__init__.py        |   0
 .../{datasetio => datasets}/test_dataset.csv   |   0
 tests/integration/datasets/test_datasets.py    |  95 +++++++++++++++
 .../test_rag_dataset.csv                       |   0
 13 files changed, 159 insertions(+), 248 deletions(-)
 delete mode 100644 tests/integration/datasetio/test_datasetio.py
 rename tests/integration/{datasetio => datasets}/__init__.py (100%)
 rename tests/integration/{datasetio => datasets}/test_dataset.csv (100%)
 create mode 100644 tests/integration/datasets/test_datasets.py
 rename tests/integration/{datasetio => datasets}/test_rag_dataset.csv (100%)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 462034c3d..e3c81ddb9 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -40,7 +40,7 @@
         }
     ],
     "paths": {
-        "/v1/datasets/{dataset_id}/append-rows": {
+        "/v1/datasetio/append-rows/{dataset_id}": {
            "post": {
                "responses": {
                    "200": {
@@ -60,7 +60,7 @@
                    }
                },
                "tags": [
-                    "Datasets"
+                    "DatasetIO"
                ],
                "description": "",
                "parameters": [
@@ -2177,7 +2177,7 @@
                }
            }
        },
-        "/v1/datasets/{dataset_id}/iterrows": {
+        "/v1/datasetio/iterrows/{dataset_id}": {
            "get": {
                "responses": {
                    "200": {
@@ -2204,7 +2204,7 @@
                    }
                },
                "tags": [
-                    "Datasets"
+                    "DatasetIO"
                ],
                "description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
                "parameters": [
@@ -10274,7 +10274,7 @@
            "name": "Benchmarks"
        },
        {
-            "name": "Datasets"
+            "name": "DatasetIO"
        },
        {
            "name": "Datasets"
@@ -10342,7 +10342,7 @@
                "Agents",
                "BatchInference (Coming Soon)",
                "Benchmarks",
-                "Datasets",
+                "DatasetIO",
                "Datasets",
                "Eval",
                "Files",
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 79adc221a..a3d4dbcc9 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -10,7 +10,7 @@ info:
 servers:
   - url: http://any-hosted-llama-stack.com
 paths:
-  /v1/datasets/{dataset_id}/append-rows:
+  /v1/datasetio/append-rows/{dataset_id}:
     post:
       responses:
         '200':
@@ -26,7 +26,7 @@ paths:
         default:
           $ref: '#/components/responses/DefaultError'
       tags:
-        - Datasets
+        - DatasetIO
       description: ''
       parameters:
         - name: dataset_id
@@ -1457,7 +1457,7 @@ paths:
           schema:
             $ref: '#/components/schemas/InvokeToolRequest'
         required: true
-  /v1/datasets/{dataset_id}/iterrows:
+  /v1/datasetio/iterrows/{dataset_id}:
     get:
       responses:
         '200':
@@ -1477,7 +1477,7 @@ paths:
         default:
           $ref: '#/components/responses/DefaultError'
      tags:
-        - Datasets
+        - DatasetIO
      description: >-
        Get a paginated list of rows from a dataset. Uses cursor-based pagination.
      parameters:
@@ -6931,7 +6931,7 @@ tags:
     Agents API for creating and interacting with agentic systems.
 - name: BatchInference (Coming Soon)
 - name: Benchmarks
-- name: Datasets
+- name: DatasetIO
 - name: Datasets
 - name: Eval
   x-displayName: >-
@@ -6971,7 +6971,7 @@ x-tagGroups:
     - Agents
     - BatchInference (Coming Soon)
    - Benchmarks
-    - Datasets
+    - DatasetIO
    - Datasets
    - Eval
    - Files
diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py
index a7ee87125..02a4776e4 100644
--- a/docs/openapi_generator/pyopenapi/generator.py
+++ b/docs/openapi_generator/pyopenapi/generator.py
@@ -552,8 +552,8 @@ class Generator:
             print(op.defining_class.__name__)
 
         # TODO (xiyan): temporary fix for datasetio inner impl + datasets api
-        if op.defining_class.__name__ in ["DatasetIO"]:
-            op.defining_class.__name__ = "Datasets"
+        # if op.defining_class.__name__ in ["DatasetIO"]:
+        #     op.defining_class.__name__ = "Datasets"
 
         doc_string = parse_type(op.func_ref)
         doc_params = dict(
diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py
index e7073fd29..6079e5b99 100644
--- a/llama_stack/apis/datasetio/datasetio.py
+++ b/llama_stack/apis/datasetio/datasetio.py
@@ -34,7 +34,8 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore
 
-    @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
+    # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
+    @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
     async def iterrows(
         self,
         dataset_id: str,
@@ -49,5 +50,5 @@ class DatasetIO(Protocol):
         """
         ...
 
-    @webmethod(route="/datasets/{dataset_id}/append-rows", method="POST")
+    @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
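As a quick sanity check of the renamed routes, a hedged sketch of hitting the datasetio endpoints directly; the host/port, the dataset id, and the exact body/query shapes are assumptions inferred from the method signatures:

```python
# Sketch only: exercise the renamed datasetio routes over HTTP.
# Assumes a stack listening on localhost:8321 and an already-registered dataset id.
import requests

BASE = "http://localhost:8321"
dataset_id = "my-dataset"  # hypothetical identifier

# GET /v1/datasetio/iterrows/{dataset_id}  (was /v1/datasets/{dataset_id}/iterrows)
resp = requests.get(f"{BASE}/v1/datasetio/iterrows/{dataset_id}", params={"limit": 5})
print(resp.json())

# POST /v1/datasetio/append-rows/{dataset_id}  (was /v1/datasets/{dataset_id}/append-rows);
# the JSON body keyed by "rows" follows the append_rows(dataset_id, rows) signature.
resp = requests.post(
    f"{BASE}/v1/datasetio/append-rows/{dataset_id}",
    json={"rows": [{"input_query": "hi", "generated_answer": "hello", "expected_answer": "hello"}]},
)
resp.raise_for_status()
```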
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 589a03b25..533993421 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -371,9 +371,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
             provider_dataset_id = dataset_id
 
         # infer provider from source
-        if source.type == DatasetType.rows:
+        if source.type == DatasetType.rows.value:
             provider_id = "localfs"
-        elif source.type == DatasetType.uri:
+        elif source.type == DatasetType.uri.value:
             # infer provider from uri
             if source.uri.startswith("huggingface"):
                 provider_id = "huggingface"
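The switch to `.value` matters when the validated source object carries its discriminator as a plain string; a small illustrative sketch (the `DatasetType` below is a stand-in enum, not the real class from llama_stack.apis.datasets):

```python
# Illustrative only: comparing a string discriminator against an Enum member fails,
# while comparing against the member's underlying string value succeeds.
from enum import Enum

class DatasetType(Enum):
    uri = "uri"
    rows = "rows"

source_type = "rows"  # what a parsed source object typically carries

print(source_type == DatasetType.rows)        # False: str vs Enum member
print(source_type == DatasetType.rows.value)  # True: compare against the string value
```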
diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index afa9ee0ff..3b0d01edd 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -3,20 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import os
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse
 
 import pandas
 
-from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import LocalFSDatasetIOConfig
@@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
 DATASETS_PREFIX = "localfs_datasets:"
 
 
-class BaseDataset(ABC):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError()
-
-    @abstractmethod
-    def __getitem__(self, idx):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def load(self):
-        raise NotImplementedError()
-
-
-@dataclass
-class DatasetInfo:
-    dataset_def: Dataset
-    dataset_impl: BaseDataset
-
-
-class PandasDataframeDataset(BaseDataset):
+class PandasDataframeDataset:
     def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
         else:
             return self.df.iloc[idx].to_dict()
 
-    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
-        # note that we will drop any columns in dataset that are not in the schema
-        df = df[self.dataset_def.dataset_schema.keys()]
-        # check all columns in dataset schema are present
-        assert len(df.columns) == len(self.dataset_def.dataset_schema)
-        # TODO: type checking against column types in dataset schema
-        return df
-
     def load(self) -> None:
         if self.df is not None:
             return
 
-        df = get_dataframe_from_url(self.dataset_def.url)
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
+        if self.dataset_def.source.type == "uri":
+            self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
+        elif self.dataset_def.source.type == "rows":
+            self.df = pandas.DataFrame(self.dataset_def.source.rows)
+        else:
+            raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
 
-        self.df = self._validate_dataset_schema(df)
+        if self.df is None:
+            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
 
 
 class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -99,29 +66,21 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 
         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
-            dataset_impl = PandasDataframeDataset(dataset)
-            self.dataset_infos[dataset.identifier] = DatasetInfo(
-                dataset_def=dataset,
-                dataset_impl=dataset_impl,
-            )
+            self.dataset_infos[dataset.identifier] = dataset
 
     async def shutdown(self) -> None: ...
 
     async def register_dataset(
         self,
-        dataset: Dataset,
+        dataset_def: Dataset,
     ) -> None:
         # Store in kvstore
-        key = f"{DATASETS_PREFIX}{dataset.identifier}"
+        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset.json(),
-        )
-        dataset_impl = PandasDataframeDataset(dataset)
-        self.dataset_infos[dataset.identifier] = DatasetInfo(
-            dataset_def=dataset,
-            dataset_impl=dataset_impl,
+            value=dataset_def.model_dump_json(),
         )
+        self.dataset_infos[dataset_def.identifier] = dataset_def
 
     async def unregister_dataset(self, dataset_id: str) -> None:
         key = f"{DATASETS_PREFIX}{dataset_id}"
@@ -134,51 +93,28 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        dataset_info.dataset_impl.load()
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
+        dataset_impl.load()
 
         start_index = start_index or 0
 
         if limit is None or limit == -1:
-            end = len(dataset_info.dataset_impl)
+            end = len(dataset_impl)
         else:
-            end = min(start_index + limit, len(dataset_info.dataset_impl))
+            end = min(start_index + limit, len(dataset_impl))
 
-        rows = dataset_info.dataset_impl[start_index:end]
+        rows = dataset_impl[start_index:end]
 
         return IterrowsResponse(
             data=rows,
-            next_index=end if end < len(dataset_info.dataset_impl) else None,
+            next_index=end if end < len(dataset_impl) else None,
         )
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        if dataset_info is None:
-            raise ValueError(f"Dataset with id {dataset_id} not found")
-
-        dataset_impl = dataset_info.dataset_impl
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
         dataset_impl.load()
 
         new_rows_df = pandas.DataFrame(rows)
-        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
         dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
-
-        url = str(dataset_info.dataset_def.url.uri)
-        parsed_url = urlparse(url)
-
-        if parsed_url.scheme == "file" or not parsed_url.scheme:
-            file_path = parsed_url.path
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            dataset_impl.df.to_csv(file_path, index=False)
-        elif parsed_url.scheme == "data":
-            # For data URLs, we need to update the base64-encoded content
-            if not parsed_url.path.startswith("text/csv;base64,"):
-                raise ValueError("Data URL must be a base64-encoded CSV")
-
-            csv_buffer = dataset_impl.df.to_csv(index=False)
-            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
-            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
-        else:
-            raise ValueError(
-                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
-            )
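To make the new load path concrete, a small self-contained sketch (plain pandas, no llama-stack imports) of how a rows-type source becomes a DataFrame and how the iterrows-style slicing above behaves; the sample rows are illustrative:

```python
# Standalone sketch of the localfs provider's new behavior: a "rows" source is
# materialized with pandas.DataFrame and paged with start_index/limit slicing.
import pandas

rows = [
    {"input_query": "1+1?", "generated_answer": "2", "expected_answer": "2"},
    {"input_query": "capital of France?", "generated_answer": "Paris", "expected_answer": "Paris"},
    {"input_query": "color of the sky?", "generated_answer": "blue", "expected_answer": "blue"},
]
df = pandas.DataFrame(rows)

start_index, limit = 0, 2
end = len(df) if limit in (None, -1) else min(start_index + limit, len(df))

page = df.iloc[start_index:end].to_dict(orient="records")
next_index = end if end < len(df) else None
print(page, next_index)  # two rows, next_index == 2
```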
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index d59edda30..41ce747f7 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
+from urllib.parse import parse_qs, urlparse
 
 import datasets as hf_datasets
 
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import HuggingfaceDatasetIOConfig
@@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
 DATASETS_PREFIX = "datasets:"
 
 
-def load_hf_dataset(dataset_def: Dataset):
-    if dataset_def.metadata.get("path", None):
-        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
-    else:
-        df = get_dataframe_from_url(dataset_def.url)
+def parse_hf_params(dataset_def: Dataset):
+    uri = dataset_def.source.uri
+    parsed_uri = urlparse(uri)
+    params = parse_qs(parsed_uri.query)
+    params = {k: v[0] for k, v in params.items()}
+    path = parsed_uri.path.lstrip("/")
 
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
-
-        dataset = hf_datasets.Dataset.from_pandas(df)
-
-        # drop columns not specified by schema
-        if dataset_def.dataset_schema:
-            dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
-
-    return dataset
+    return path, params
 
 
 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset_def.json(),
+            value=dataset_def.model_dump_json(),
         )
         self.dataset_infos[dataset_def.identifier] = dataset_def
 
@@ -80,7 +72,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)
 
         start_index = start_index or 0
 
@@ -98,7 +91,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)
 
         # Convert rows to HF Dataset format
         new_dataset = hf_datasets.Dataset.from_list(rows)
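A quick standalone check of what parse_hf_params-style parsing yields for the huggingface:// URIs this provider accepts (re-implemented here so it runs without llama-stack installed):

```python
# Standalone re-implementation of the parse_hf_params logic, for illustration only.
from urllib.parse import parse_qs, urlparse

def parse_hf_uri(uri: str):
    parsed = urlparse(uri)
    params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
    return parsed.path.lstrip("/"), params

path, params = parse_hf_uri("huggingface://datasets/llamastack/simpleqa?split=train")
print(path, params)
# -> llamastack/simpleqa {'split': 'train'}  (urlparse treats "datasets" as the netloc, which is dropped)
# These feed hf_datasets.load_dataset("llamastack/simpleqa", split="train")
```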
diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py
index f54cb55eb..6a544ea49 100644
--- a/llama_stack/providers/utils/datasetio/url_utils.py
+++ b/llama_stack/providers/utils/datasetio/url_utils.py
@@ -10,18 +10,17 @@
 from urllib.parse import unquote
 
 import pandas
-from llama_stack.apis.common.content_types import URL
 from llama_stack.providers.utils.memory.vector_store import parse_data_url
 
 
-def get_dataframe_from_url(url: URL):
+def get_dataframe_from_uri(uri: str):
     df = None
-    if url.uri.endswith(".csv"):
-        df = pandas.read_csv(url.uri)
-    elif url.uri.endswith(".xlsx"):
-        df = pandas.read_excel(url.uri)
-    elif url.uri.startswith("data:"):
-        parts = parse_data_url(url.uri)
+    if uri.endswith(".csv"):
+        df = pandas.read_csv(uri)
+    elif uri.endswith(".xlsx"):
+        df = pandas.read_excel(uri)
+    elif uri.startswith("data:"):
+        parts = parse_data_url(uri)
         data = parts["data"]
         if parts["is_base64"]:
             data = base64.b64decode(data)
@@ -39,6 +38,6 @@ def get_dataframe_from_url(url: URL):
         else:
             df = pandas.read_excel(data_bytes)
     else:
-        raise ValueError(f"Unsupported file type: {url}")
+        raise ValueError(f"Unsupported file type: {uri}")
 
     return df
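The data: URL path is the one the localfs tests exercise; a self-contained sketch of round-tripping a CSV through a base64 data URL, roughly what the test helper and get_dataframe_from_uri do between them (plain stdlib and pandas):

```python
# Standalone sketch: build a base64 CSV data URL and read it back into a DataFrame.
import base64
import io

import pandas

csv_text = "input_query,expected_answer\n1+1?,2\n"
data_url = "data:text/csv;base64," + base64.b64encode(csv_text.encode("utf-8")).decode("utf-8")

# Decode: split off the "data:<mimetype>;base64," prefix and base64-decode the payload.
header, payload = data_url.split(",", 1)
assert header.endswith(";base64")
df = pandas.read_csv(io.BytesIO(base64.b64decode(payload)))
print(df.to_dict(orient="records"))  # [{'input_query': '1+1?', 'expected_answer': 2}]
```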
diff --git a/tests/integration/datasetio/test_datasetio.py b/tests/integration/datasetio/test_datasetio.py
deleted file mode 100644
index 459589e7b..000000000
--- a/tests/integration/datasetio/test_datasetio.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import base64
-import mimetypes
-import os
-from pathlib import Path
-
-import pytest
-
-# How to run this test:
-#
-# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio
-
-
-@pytest.fixture
-def dataset_for_test(llama_stack_client):
-    dataset_id = "test_dataset"
-    register_dataset(llama_stack_client, dataset_id=dataset_id)
-    yield
-    # Teardown - this always runs, even if the test fails
-    try:
-        llama_stack_client.datasets.unregister(dataset_id)
-    except Exception as e:
-        print(f"Warning: Failed to unregister test_dataset: {e}")
-
-
-def data_url_from_file(file_path: str) -> str:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-
-    with open(file_path, "rb") as file:
-        file_content = file.read()
-
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
-
-    data_url = f"data:{mime_type};base64,{base64_content}"
-
-    return data_url
-
-
-def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
-    if for_rag:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
-    else:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
-    test_url = data_url_from_file(str(test_file))
-
-    if for_generation:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "chat_completion_input": {"type": "chat_completion_input"},
-        }
-    elif for_rag:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-            "context": {"type": "string"},
-        }
-    else:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-        }
-
-    dataset_providers = [x for x in llama_stack_client.providers.list() if x.api == "datasetio"]
-    dataset_provider_id = dataset_providers[0].provider_id
-
-    llama_stack_client.datasets.register(
-        dataset_id=dataset_id,
-        dataset_schema=dataset_schema,
-        url=dict(uri=test_url),
-        provider_id=dataset_provider_id,
-    )
-
-
-def test_register_unregister_dataset(llama_stack_client):
-    register_dataset(llama_stack_client)
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 1
-    assert response[0].identifier == "test_dataset"
-
-    llama_stack_client.datasets.unregister("test_dataset")
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 0
-
-
-def test_get_rows_paginated(llama_stack_client, dataset_for_test):
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=3,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 3
-    assert response.next_page_token == "3"
-
-    # iterate over all rows
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=2,
-        page_token=response.next_page_token,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 2
-    assert response.next_page_token == "5"
diff --git a/tests/integration/datasetio/__init__.py b/tests/integration/datasets/__init__.py
similarity index 100%
rename from tests/integration/datasetio/__init__.py
rename to tests/integration/datasets/__init__.py
diff --git a/tests/integration/datasetio/test_dataset.csv b/tests/integration/datasets/test_dataset.csv
similarity index 100%
rename from tests/integration/datasetio/test_dataset.csv
rename to tests/integration/datasets/test_dataset.csv
diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py
new file mode 100644
index 000000000..60db95f30
--- /dev/null
+++ b/tests/integration/datasets/test_datasets.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import base64
+import mimetypes
+import os
+
+import pytest
+
+# How to run this test:
+#
+# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets
+
+
+def data_url_from_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
+@pytest.mark.parametrize(
+    "purpose, source, provider_id, limit",
+    [
+        (
+            "eval/messages-answer",
+            {
+                "type": "uri",
+                "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+            },
+            "huggingface",
+            10,
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "rows",
+                "rows": [
+                    {
+                        "messages": [{"role": "user", "content": "Hello, world!"}],
+                        "answer": "Hello, world!",
+                    },
+                    {
+                        "messages": [
+                            {
+                                "role": "user",
+                                "content": "What is the capital of France?",
+                            }
+                        ],
+                        "answer": "Paris",
+                    },
+                ],
+            },
+            "localfs",
+            2,
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "uri",
+                "uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
+            },
+            "localfs",
+            5,
+        ),
+    ],
+)
+def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
+    dataset = llama_stack_client.datasets.register(
+        purpose=purpose,
+        source=source,
+    )
+    assert dataset.identifier is not None
+    assert dataset.provider_id == provider_id
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
+    assert len(iterrow_response.data) == limit
+
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier in [d.identifier for d in dataset_list]
+
+    llama_stack_client.datasets.unregister(dataset.identifier)
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier not in [d.identifier for d in dataset_list]
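Since iterrows is now cursor-style (next_index instead of the old page tokens), a hedged sketch of walking an entire dataset page by page; passing start_index as a keyword argument is an assumption inferred from the server-side signature:

```python
# Sketch: drain a dataset with the cursor-style iterrows API.
# Assumes `client` is a LlamaStackClient and `dataset_id` names a registered dataset.
def iter_all_rows(client, dataset_id, page_size=100):
    start_index = 0
    while True:
        page = client.datasets.iterrows(dataset_id, start_index=start_index, limit=page_size)
        yield from page.data
        if page.next_index is None:
            break
        start_index = page.next_index
```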
diff --git a/tests/integration/datasetio/test_rag_dataset.csv b/tests/integration/datasets/test_rag_dataset.csv
similarity index 100%
rename from tests/integration/datasetio/test_rag_dataset.csv
rename to tests/integration/datasets/test_rag_dataset.csv