feat(dataset api): (1.5/n) fix dataset registration (#1659)

# What does this PR do?

- fix dataset registration & iterrows
> NOTE: the URL endpoints move from `/v1/datasets/{dataset_id}/...` to `/v1/datasetio/...` because path routing was flaky with the old paths


## Test Plan
```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```
<img width="854" alt="image"
src="https://github.com/user-attachments/assets/0168b352-1c5a-48d1-8e9a-93141d418e54"
/>


Commit a568bf3f9d (parent 2c9d624910), authored by Xi Yan on 2025-03-15 16:48:09 -07:00 and committed via GitHub.
13 changed files with 159 additions and 248 deletions

Generated OpenAPI spec (JSON):

```diff
@@ -40,7 +40,7 @@
       }
     ],
   "paths": {
-    "/v1/datasets/{dataset_id}/append-rows": {
+    "/v1/datasetio/append-rows/{dataset_id}": {
       "post": {
         "responses": {
           "200": {
@@ -60,7 +60,7 @@
           }
         },
         "tags": [
-          "Datasets"
+          "DatasetIO"
         ],
         "description": "",
         "parameters": [
@@ -2177,7 +2177,7 @@
           }
         }
       },
-    "/v1/datasets/{dataset_id}/iterrows": {
+    "/v1/datasetio/iterrows/{dataset_id}": {
       "get": {
         "responses": {
           "200": {
@@ -2204,7 +2204,7 @@
           }
         },
         "tags": [
-          "Datasets"
+          "DatasetIO"
         ],
         "description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
         "parameters": [
@@ -10274,7 +10274,7 @@
       "name": "Benchmarks"
     },
     {
-      "name": "Datasets"
+      "name": "DatasetIO"
     },
     {
       "name": "Datasets"
@@ -10342,7 +10342,7 @@
         "Agents",
         "BatchInference (Coming Soon)",
         "Benchmarks",
-        "Datasets",
+        "DatasetIO",
         "Datasets",
         "Eval",
         "Files",
```

Generated OpenAPI spec (YAML):

```diff
@@ -10,7 +10,7 @@ info:
 servers:
   - url: http://any-hosted-llama-stack.com
 paths:
-  /v1/datasets/{dataset_id}/append-rows:
+  /v1/datasetio/append-rows/{dataset_id}:
     post:
       responses:
         '200':
@@ -26,7 +26,7 @@ paths:
         default:
           $ref: '#/components/responses/DefaultError'
       tags:
-        - Datasets
+        - DatasetIO
       description: ''
       parameters:
         - name: dataset_id
@@ -1457,7 +1457,7 @@ paths:
             schema:
               $ref: '#/components/schemas/InvokeToolRequest'
         required: true
-  /v1/datasets/{dataset_id}/iterrows:
+  /v1/datasetio/iterrows/{dataset_id}:
     get:
       responses:
         '200':
@@ -1477,7 +1477,7 @@ paths:
         default:
           $ref: '#/components/responses/DefaultError'
       tags:
-        - Datasets
+        - DatasetIO
       description: >-
         Get a paginated list of rows from a dataset. Uses cursor-based pagination.
       parameters:
@@ -6931,7 +6931,7 @@ tags:
       Agents API for creating and interacting with agentic systems.
   - name: BatchInference (Coming Soon)
   - name: Benchmarks
-  - name: Datasets
+  - name: DatasetIO
   - name: Datasets
   - name: Eval
     x-displayName: >-
@@ -6971,7 +6971,7 @@ x-tagGroups:
       - Agents
       - BatchInference (Coming Soon)
      - Benchmarks
-      - Datasets
+      - DatasetIO
      - Datasets
      - Eval
      - Files
```

OpenAPI generator:

```diff
@@ -552,8 +552,8 @@ class Generator:
             print(op.defining_class.__name__)
             # TODO (xiyan): temporary fix for datasetio inner impl + datasets api
-            if op.defining_class.__name__ in ["DatasetIO"]:
-                op.defining_class.__name__ = "Datasets"
+            # if op.defining_class.__name__ in ["DatasetIO"]:
+            #     op.defining_class.__name__ = "Datasets"
             doc_string = parse_type(op.func_ref)
             doc_params = dict(
```

DatasetIO API protocol (llama_stack.apis.datasetio):

```diff
@@ -34,7 +34,8 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore

-    @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
+    # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
+    @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
     async def iterrows(
         self,
         dataset_id: str,
@@ -49,5 +50,5 @@ class DatasetIO(Protocol):
         """
         ...

-    @webmethod(route="/datasets/{dataset_id}/append-rows", method="POST")
+    @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
```
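
For orientation, here is a minimal sketch of what the moved routes look like over plain HTTP after this change. The paths, query parameters, and response fields come from the spec and protocol diffs above; the base URL, dataset id, row contents, and the exact JSON body shape for append-rows are placeholders/assumptions, and a real deployment would normally go through `llama_stack_client` instead.

```python
import requests

BASE_URL = "http://localhost:8321"  # placeholder; point at your Llama Stack server
DATASET_ID = "my_dataset"           # placeholder dataset identifier

# GET /v1/datasetio/iterrows/{dataset_id} (was /v1/datasets/{dataset_id}/iterrows)
resp = requests.get(
    f"{BASE_URL}/v1/datasetio/iterrows/{DATASET_ID}",
    params={"start_index": 0, "limit": 10},
)
resp.raise_for_status()
page = resp.json()
print(page["data"], page.get("next_index"))

# POST /v1/datasetio/append-rows/{dataset_id} (was /v1/datasets/{dataset_id}/append-rows)
# Body shape assumed: the rows argument serialized as a JSON field.
resp = requests.post(
    f"{BASE_URL}/v1/datasetio/append-rows/{DATASET_ID}",
    json={"rows": [{"input_query": "What is 2+2?", "expected_answer": "4", "generated_answer": "4"}]},
)
resp.raise_for_status()
```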

Datasets routing table:

```diff
@@ -371,9 +371,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
             provider_dataset_id = dataset_id

         # infer provider from source
-        if source.type == DatasetType.rows:
+        if source.type == DatasetType.rows.value:
             provider_id = "localfs"
-        elif source.type == DatasetType.uri:
+        elif source.type == DatasetType.uri.value:
             # infer provider from uri
             if source.uri.startswith("huggingface"):
                 provider_id = "huggingface"
```

localfs datasetio provider:

```diff
@@ -3,20 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import os
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse

 import pandas

-from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from .config import LocalFSDatasetIOConfig
@@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
 DATASETS_PREFIX = "localfs_datasets:"

-class BaseDataset(ABC):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError()
-
-    @abstractmethod
-    def __getitem__(self, idx):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def load(self):
-        raise NotImplementedError()
-
-
-@dataclass
-class DatasetInfo:
-    dataset_def: Dataset
-    dataset_impl: BaseDataset
-
-
-class PandasDataframeDataset(BaseDataset):
+class PandasDataframeDataset:
     def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
         else:
             return self.df.iloc[idx].to_dict()

-    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
-        # note that we will drop any columns in dataset that are not in the schema
-        df = df[self.dataset_def.dataset_schema.keys()]
-        # check all columns in dataset schema are present
-        assert len(df.columns) == len(self.dataset_def.dataset_schema)
-        # TODO: type checking against column types in dataset schema
-        return df
-
     def load(self) -> None:
         if self.df is not None:
             return

-        df = get_dataframe_from_url(self.dataset_def.url)
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
+        if self.dataset_def.source.type == "uri":
+            self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
+        elif self.dataset_def.source.type == "rows":
+            self.df = pandas.DataFrame(self.dataset_def.source.rows)
+        else:
+            raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")

-        self.df = self._validate_dataset_schema(df)
+        if self.df is None:
+            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")


 class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -99,29 +66,21 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
-            dataset_impl = PandasDataframeDataset(dataset)
-            self.dataset_infos[dataset.identifier] = DatasetInfo(
-                dataset_def=dataset,
-                dataset_impl=dataset_impl,
-            )
+            self.dataset_infos[dataset.identifier] = dataset

     async def shutdown(self) -> None: ...

     async def register_dataset(
         self,
-        dataset: Dataset,
+        dataset_def: Dataset,
     ) -> None:
         # Store in kvstore
-        key = f"{DATASETS_PREFIX}{dataset.identifier}"
+        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset.json(),
+            value=dataset_def.model_dump_json(),
         )
-        dataset_impl = PandasDataframeDataset(dataset)
-        self.dataset_infos[dataset.identifier] = DatasetInfo(
-            dataset_def=dataset,
-            dataset_impl=dataset_impl,
-        )
+        self.dataset_infos[dataset_def.identifier] = dataset_def

     async def unregister_dataset(self, dataset_id: str) -> None:
         key = f"{DATASETS_PREFIX}{dataset_id}"
@@ -134,51 +93,28 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        dataset_info.dataset_impl.load()
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
+        dataset_impl.load()

         start_index = start_index or 0

         if limit is None or limit == -1:
-            end = len(dataset_info.dataset_impl)
+            end = len(dataset_impl)
         else:
-            end = min(start_index + limit, len(dataset_info.dataset_impl))
+            end = min(start_index + limit, len(dataset_impl))

-        rows = dataset_info.dataset_impl[start_index:end]
+        rows = dataset_impl[start_index:end]

         return IterrowsResponse(
             data=rows,
-            next_index=end if end < len(dataset_info.dataset_impl) else None,
+            next_index=end if end < len(dataset_impl) else None,
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        if dataset_info is None:
-            raise ValueError(f"Dataset with id {dataset_id} not found")
-
-        dataset_impl = dataset_info.dataset_impl
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
         dataset_impl.load()

         new_rows_df = pandas.DataFrame(rows)
-        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
         dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
-
-        url = str(dataset_info.dataset_def.url.uri)
-        parsed_url = urlparse(url)
-
-        if parsed_url.scheme == "file" or not parsed_url.scheme:
-            file_path = parsed_url.path
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            dataset_impl.df.to_csv(file_path, index=False)
-        elif parsed_url.scheme == "data":
-            # For data URLs, we need to update the base64-encoded content
-            if not parsed_url.path.startswith("text/csv;base64,"):
-                raise ValueError("Data URL must be a base64-encoded CSV")
-
-            csv_buffer = dataset_impl.df.to_csv(index=False)
-            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
-            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
-        else:
-            raise ValueError(
-                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
-            )
```
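
The iterrows logic above is now stateless: each call rebuilds a `PandasDataframeDataset` from the stored `Dataset` definition, loads it, and slices it. A self-contained sketch of the same start_index/limit/next_index arithmetic (illustrative data, not the provider code):

```python
from typing import Any, Dict, List, Optional, Tuple

rows: List[Dict[str, Any]] = [{"i": i} for i in range(5)]  # stand-in for the loaded dataframe rows


def paginate(
    rows: List[Dict[str, Any]], start_index: Optional[int] = None, limit: Optional[int] = None
) -> Tuple[List[Dict[str, Any]], Optional[int]]:
    # Mirrors the localfs iterrows slicing: limit of None or -1 means "to the end".
    start_index = start_index or 0
    end = len(rows) if limit is None or limit == -1 else min(start_index + limit, len(rows))
    return rows[start_index:end], (end if end < len(rows) else None)


page, next_index = paginate(rows, start_index=0, limit=2)   # 2 rows, next_index == 2
page, next_index = paginate(rows, start_index=2, limit=-1)  # remaining 3 rows, next_index is None
```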

Hugging Face datasetio provider:

```diff
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
+from urllib.parse import parse_qs, urlparse

 import datasets as hf_datasets

 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from .config import HuggingfaceDatasetIOConfig
@@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
 DATASETS_PREFIX = "datasets:"

-def load_hf_dataset(dataset_def: Dataset):
-    if dataset_def.metadata.get("path", None):
-        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
-    else:
-        df = get_dataframe_from_url(dataset_def.url)
-
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
-
-        dataset = hf_datasets.Dataset.from_pandas(df)
-
-    # drop columns not specified by schema
-    if dataset_def.dataset_schema:
-        dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
-
-    return dataset
+def parse_hf_params(dataset_def: Dataset):
+    uri = dataset_def.source.uri
+    parsed_uri = urlparse(uri)
+    params = parse_qs(parsed_uri.query)
+    params = {k: v[0] for k, v in params.items()}
+    path = parsed_uri.path.lstrip("/")
+
+    return path, params


 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset_def.json(),
+            value=dataset_def.model_dump_json(),
         )
         self.dataset_infos[dataset_def.identifier] = dataset_def
@@ -80,7 +72,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         start_index = start_index or 0
@@ -98,7 +91,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         # Convert rows to HF Dataset format
         new_dataset = hf_datasets.Dataset.from_list(rows)
```
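
To make the new `parse_hf_params` concrete, here is how the Hugging Face URI used in the new integration test breaks down (pure stdlib, no Llama Stack imports needed):

```python
from urllib.parse import parse_qs, urlparse

uri = "huggingface://datasets/llamastack/simpleqa?split=train"
parsed = urlparse(uri)

path = parsed.path.lstrip("/")
params = {k: v[0] for k, v in parse_qs(parsed.query).items()}

print(parsed.scheme)  # "huggingface"
print(parsed.netloc)  # "datasets" (not part of the path handed to load_dataset)
print(path)           # "llamastack/simpleqa"
print(params)         # {"split": "train"}
# The provider then calls hf_datasets.load_dataset("llamastack/simpleqa", split="train").
```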

datasetio URL utilities (llama_stack.providers.utils.datasetio.url_utils):

```diff
@@ -10,18 +10,17 @@ from urllib.parse import unquote

 import pandas

-from llama_stack.apis.common.content_types import URL
 from llama_stack.providers.utils.memory.vector_store import parse_data_url


-def get_dataframe_from_url(url: URL):
+def get_dataframe_from_uri(uri: str):
     df = None
-    if url.uri.endswith(".csv"):
-        df = pandas.read_csv(url.uri)
-    elif url.uri.endswith(".xlsx"):
-        df = pandas.read_excel(url.uri)
-    elif url.uri.startswith("data:"):
-        parts = parse_data_url(url.uri)
+    if uri.endswith(".csv"):
+        df = pandas.read_csv(uri)
+    elif uri.endswith(".xlsx"):
+        df = pandas.read_excel(uri)
+    elif uri.startswith("data:"):
+        parts = parse_data_url(uri)
         data = parts["data"]

         if parts["is_base64"]:
             data = base64.b64decode(data)
@@ -39,6 +38,6 @@ def get_dataframe_from_url(url: URL):
         else:
             df = pandas.read_excel(data_bytes)
     else:
-        raise ValueError(f"Unsupported file type: {url}")
+        raise ValueError(f"Unsupported file type: {uri}")

     return df
```
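
A quick usage sketch of the renamed helper, assuming it is importable as shown in the localfs diff above and that base64 `text/csv` data URLs resolve the way the integration tests exercise them:

```python
import base64

from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri

# Build a base64 CSV data: URL, the same shape data_url_from_file() produces in the tests.
csv_bytes = b"input_query,expected_answer,generated_answer\nWhat is 2+2?,4,4\n"
data_url = "data:text/csv;base64," + base64.b64encode(csv_bytes).decode("utf-8")

df = get_dataframe_from_uri(data_url)  # decoded, then parsed with pandas
print(df.columns.tolist())             # ['input_query', 'expected_answer', 'generated_answer']

# Plain file-style URIs are dispatched on extension instead:
# get_dataframe_from_uri("eval_rows.csv")  or  get_dataframe_from_uri("eval_rows.xlsx")
```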

Deleted: old datasetio integration test (under tests/integration/datasetio):

```diff
@@ -1,114 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-import base64
-import mimetypes
-import os
-from pathlib import Path
-
-import pytest
-
-# How to run this test:
-#
-# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio
-
-
-@pytest.fixture
-def dataset_for_test(llama_stack_client):
-    dataset_id = "test_dataset"
-    register_dataset(llama_stack_client, dataset_id=dataset_id)
-    yield
-    # Teardown - this always runs, even if the test fails
-    try:
-        llama_stack_client.datasets.unregister(dataset_id)
-    except Exception as e:
-        print(f"Warning: Failed to unregister test_dataset: {e}")
-
-
-def data_url_from_file(file_path: str) -> str:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-
-    with open(file_path, "rb") as file:
-        file_content = file.read()
-
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
-
-    data_url = f"data:{mime_type};base64,{base64_content}"
-
-    return data_url
-
-
-def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
-    if for_rag:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
-    else:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
-    test_url = data_url_from_file(str(test_file))
-
-    if for_generation:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "chat_completion_input": {"type": "chat_completion_input"},
-        }
-    elif for_rag:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-            "context": {"type": "string"},
-        }
-    else:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-        }
-
-    dataset_providers = [x for x in llama_stack_client.providers.list() if x.api == "datasetio"]
-    dataset_provider_id = dataset_providers[0].provider_id
-
-    llama_stack_client.datasets.register(
-        dataset_id=dataset_id,
-        dataset_schema=dataset_schema,
-        url=dict(uri=test_url),
-        provider_id=dataset_provider_id,
-    )
-
-
-def test_register_unregister_dataset(llama_stack_client):
-    register_dataset(llama_stack_client)
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 1
-    assert response[0].identifier == "test_dataset"
-
-    llama_stack_client.datasets.unregister("test_dataset")
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 0
-
-
-def test_get_rows_paginated(llama_stack_client, dataset_for_test):
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=3,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 3
-    assert response.next_page_token == "3"
-
-    # iterate over all rows
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=2,
-        page_token=response.next_page_token,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 2
-    assert response.next_page_token == "5"
```

Added: new datasets integration test (tests/integration/datasets/test_datasets.py):

```diff
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import base64
+import mimetypes
+import os
+
+import pytest
+
+# How to run this test:
+#
+# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets
+
+
+def data_url_from_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
+@pytest.mark.parametrize(
+    "purpose, source, provider_id, limit",
+    [
+        (
+            "eval/messages-answer",
+            {
+                "type": "uri",
+                "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+            },
+            "huggingface",
+            10,
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "rows",
+                "rows": [
+                    {
+                        "messages": [{"role": "user", "content": "Hello, world!"}],
+                        "answer": "Hello, world!",
+                    },
+                    {
+                        "messages": [
+                            {
+                                "role": "user",
+                                "content": "What is the capital of France?",
+                            }
+                        ],
+                        "answer": "Paris",
+                    },
+                ],
+            },
+            "localfs",
+            2,
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "uri",
+                "uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
+            },
+            "localfs",
+            5,
+        ),
+    ],
+)
+def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
+    dataset = llama_stack_client.datasets.register(
+        purpose=purpose,
+        source=source,
+    )
+    assert dataset.identifier is not None
+    assert dataset.provider_id == provider_id
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
+    assert len(iterrow_response.data) == limit
+
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier in [d.identifier for d in dataset_list]
+
+    llama_stack_client.datasets.unregister(dataset.identifier)
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier not in [d.identifier for d in dataset_list]
```
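
Beyond the single-page assertions in the test, cursor-style iteration with the new API would look roughly like this. A sketch only: it assumes the client's `datasets.iterrows` accepts `start_index` the same way the server route in this PR does.

```python
# llama_stack_client is an already-constructed client, as in the test above;
# dataset is the object returned by llama_stack_client.datasets.register(...).
start_index = 0
while True:
    page = llama_stack_client.datasets.iterrows(dataset.identifier, start_index=start_index, limit=100)
    for row in page.data:
        ...  # each row is a plain dict, e.g. {"messages": [...], "answer": "..."}
    if page.next_index is None:
        break
    start_index = page.next_index
```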