feat(dataset api): (1.5/n) fix dataset registration (#1659)
# What does this PR do?

- fix dataset registration & iterrows

> NOTE: the URL endpoint is changed to datasetio due to flaky path routing

## Test Plan

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```

<img width="854" alt="image" src="https://github.com/user-attachments/assets/0168b352-1c5a-48d1-8e9a-93141d418e54" />
parent 2c9d624910
commit a568bf3f9d
13 changed files with 159 additions and 248 deletions
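For orientation, the end-to-end flow this change enables (and that the new integration test exercises) looks roughly like the sketch below. It assumes a running stack reachable at the default local address and an installed `llama-stack-client`; the sample rows are illustrative.

```python
from llama_stack_client import LlamaStackClient

# Assumed local endpoint; adjust to wherever your stack is serving.
client = LlamaStackClient(base_url="http://localhost:8321")

# Register a dataset from inline rows (the "rows" source routes to the localfs provider).
dataset = client.datasets.register(
    purpose="eval/messages-answer",
    source={
        "type": "rows",
        "rows": [
            {
                "messages": [{"role": "user", "content": "What is the capital of France?"}],
                "answer": "Paris",
            },
        ],
    },
)

# Read rows back; this now goes through the relocated /v1/datasetio/iterrows/{dataset_id} route.
page = client.datasets.iterrows(dataset.identifier, limit=1)
print(page.data)

client.datasets.unregister(dataset.identifier)
```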
docs/_static/llama-stack-spec.html (vendored, 12 changes)

@@ -40,7 +40,7 @@
         }
     ],
     "paths": {
-        "/v1/datasets/{dataset_id}/append-rows": {
+        "/v1/datasetio/append-rows/{dataset_id}": {
            "post": {
                "responses": {
                    "200": {
@@ -60,7 +60,7 @@
                    }
                },
                "tags": [
-                    "Datasets"
+                    "DatasetIO"
                ],
                "description": "",
                "parameters": [
@@ -2177,7 +2177,7 @@
                }
            }
        },
-        "/v1/datasets/{dataset_id}/iterrows": {
+        "/v1/datasetio/iterrows/{dataset_id}": {
            "get": {
                "responses": {
                    "200": {
@@ -2204,7 +2204,7 @@
                    }
                },
                "tags": [
-                    "Datasets"
+                    "DatasetIO"
                ],
                "description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
                "parameters": [
@@ -10274,7 +10274,7 @@
            "name": "Benchmarks"
        },
        {
-            "name": "Datasets"
+            "name": "DatasetIO"
        },
        {
            "name": "Datasets"
@@ -10342,7 +10342,7 @@
                "Agents",
                "BatchInference (Coming Soon)",
                "Benchmarks",
-                "Datasets",
+                "DatasetIO",
                "Datasets",
                "Eval",
                "Files",

docs/_static/llama-stack-spec.yaml (vendored, 12 changes)

@@ -10,7 +10,7 @@ info:
 servers:
   - url: http://any-hosted-llama-stack.com
 paths:
-  /v1/datasets/{dataset_id}/append-rows:
+  /v1/datasetio/append-rows/{dataset_id}:
     post:
       responses:
         '200':
@@ -26,7 +26,7 @@ paths:
        default:
          $ref: '#/components/responses/DefaultError'
      tags:
-        - Datasets
+        - DatasetIO
      description: ''
      parameters:
        - name: dataset_id
@@ -1457,7 +1457,7 @@ paths:
            schema:
              $ref: '#/components/schemas/InvokeToolRequest'
        required: true
-  /v1/datasets/{dataset_id}/iterrows:
+  /v1/datasetio/iterrows/{dataset_id}:
    get:
      responses:
        '200':
@@ -1477,7 +1477,7 @@ paths:
        default:
          $ref: '#/components/responses/DefaultError'
      tags:
-        - Datasets
+        - DatasetIO
      description: >-
        Get a paginated list of rows from a dataset. Uses cursor-based pagination.
      parameters:
@@ -6931,7 +6931,7 @@ tags:
      Agents API for creating and interacting with agentic systems.
  - name: BatchInference (Coming Soon)
  - name: Benchmarks
-  - name: Datasets
+  - name: DatasetIO
  - name: Datasets
  - name: Eval
    x-displayName: >-
@@ -6971,7 +6971,7 @@ x-tagGroups:
      - Agents
      - BatchInference (Coming Soon)
      - Benchmarks
-      - Datasets
+      - DatasetIO
      - Datasets
      - Eval
      - Files

@@ -552,8 +552,8 @@ class Generator:
            print(op.defining_class.__name__)

            # TODO (xiyan): temporary fix for datasetio inner impl + datasets api
-            if op.defining_class.__name__ in ["DatasetIO"]:
-                op.defining_class.__name__ = "Datasets"
+            # if op.defining_class.__name__ in ["DatasetIO"]:
+            #     op.defining_class.__name__ = "Datasets"

            doc_string = parse_type(op.func_ref)
            doc_params = dict(

@@ -34,7 +34,8 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore

-    @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
+    # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
+    @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
     async def iterrows(
         self,
         dataset_id: str,
@@ -49,5 +50,5 @@
         """
         ...

-    @webmethod(route="/datasets/{dataset_id}/append-rows", method="POST")
+    @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
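The protocol change above is what moves the REST surface from `/v1/datasets/...` to `/v1/datasetio/...`, matching the regenerated spec earlier in this diff. A minimal sketch of the resulting HTTP calls; the base URL, dataset id, and row payload are placeholders, and the query-parameter names simply follow the `iterrows` signature:

```python
import httpx

BASE_URL = "http://localhost:8321"  # assumed local stack address
dataset_id = "my-dataset"           # placeholder identifier

# GET /v1/datasetio/iterrows/{dataset_id} with start_index/limit for cursor-style paging
resp = httpx.get(
    f"{BASE_URL}/v1/datasetio/iterrows/{dataset_id}",
    params={"start_index": 0, "limit": 10},
)
print(resp.json())

# POST /v1/datasetio/append-rows/{dataset_id} with the rows to append
httpx.post(
    f"{BASE_URL}/v1/datasetio/append-rows/{dataset_id}",
    json={"rows": [{"messages": [{"role": "user", "content": "hi"}], "answer": "hi"}]},
)
```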

@@ -371,9 +371,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
            provider_dataset_id = dataset_id

        # infer provider from source
-        if source.type == DatasetType.rows:
+        if source.type == DatasetType.rows.value:
            provider_id = "localfs"
-        elif source.type == DatasetType.uri:
+        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
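The switch to comparing against `.value` matters when `source.type` arrives as a plain string (for example a pydantic `Literal` discriminator) rather than an enum member; assuming `DatasetType` is a plain `Enum` and not a `str` mixin, the old equality check silently fails. A standalone illustration of the pitfall:

```python
from enum import Enum


class DatasetType(Enum):
    # local stand-in for the real enum, assumed here to not be a str-mixin
    uri = "uri"
    rows = "rows"


source_type = "rows"  # what a serialized/validated source typically carries

print(source_type == DatasetType.rows)        # False: str vs. Enum member
print(source_type == DatasetType.rows.value)  # True:  str vs. str
```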

@@ -3,20 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import os
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-from urllib.parse import urlparse

 import pandas

-from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from .config import LocalFSDatasetIOConfig
@@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
 DATASETS_PREFIX = "localfs_datasets:"


-class BaseDataset(ABC):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError()
-
-    @abstractmethod
-    def __getitem__(self, idx):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def load(self):
-        raise NotImplementedError()
-
-
-@dataclass
-class DatasetInfo:
-    dataset_def: Dataset
-    dataset_impl: BaseDataset
-
-
-class PandasDataframeDataset(BaseDataset):
+class PandasDataframeDataset:
     def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
         else:
             return self.df.iloc[idx].to_dict()

-    def _validate_dataset_schema(self, df) -> pandas.DataFrame:
-        # note that we will drop any columns in dataset that are not in the schema
-        df = df[self.dataset_def.dataset_schema.keys()]
-        # check all columns in dataset schema are present
-        assert len(df.columns) == len(self.dataset_def.dataset_schema)
-        # TODO: type checking against column types in dataset schema
-        return df
-
     def load(self) -> None:
         if self.df is not None:
             return

-        df = get_dataframe_from_url(self.dataset_def.url)
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
+        if self.dataset_def.source.type == "uri":
+            self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
+        elif self.dataset_def.source.type == "rows":
+            self.df = pandas.DataFrame(self.dataset_def.source.rows)
+        else:
+            raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")

-        self.df = self._validate_dataset_schema(df)
+        if self.df is None:
+            raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")


 class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -99,29 +66,21 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):

         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
-            dataset_impl = PandasDataframeDataset(dataset)
-            self.dataset_infos[dataset.identifier] = DatasetInfo(
-                dataset_def=dataset,
-                dataset_impl=dataset_impl,
-            )
+            self.dataset_infos[dataset.identifier] = dataset

     async def shutdown(self) -> None: ...

     async def register_dataset(
         self,
-        dataset: Dataset,
+        dataset_def: Dataset,
     ) -> None:
         # Store in kvstore
-        key = f"{DATASETS_PREFIX}{dataset.identifier}"
+        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset.json(),
+            value=dataset_def.model_dump_json(),
         )
-        dataset_impl = PandasDataframeDataset(dataset)
-        self.dataset_infos[dataset.identifier] = DatasetInfo(
-            dataset_def=dataset,
-            dataset_impl=dataset_impl,
-        )
+        self.dataset_infos[dataset_def.identifier] = dataset_def

     async def unregister_dataset(self, dataset_id: str) -> None:
         key = f"{DATASETS_PREFIX}{dataset_id}"
@@ -134,51 +93,28 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        dataset_info.dataset_impl.load()
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
+        dataset_impl.load()

         start_index = start_index or 0

         if limit is None or limit == -1:
-            end = len(dataset_info.dataset_impl)
+            end = len(dataset_impl)
         else:
-            end = min(start_index + limit, len(dataset_info.dataset_impl))
+            end = min(start_index + limit, len(dataset_impl))

-        rows = dataset_info.dataset_impl[start_index:end]
+        rows = dataset_impl[start_index:end]

         return IterrowsResponse(
             data=rows,
-            next_index=end if end < len(dataset_info.dataset_impl) else None,
+            next_index=end if end < len(dataset_impl) else None,
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        dataset_info = self.dataset_infos.get(dataset_id)
-        if dataset_info is None:
-            raise ValueError(f"Dataset with id {dataset_id} not found")
-
-        dataset_impl = dataset_info.dataset_impl
+        dataset_def = self.dataset_infos[dataset_id]
+        dataset_impl = PandasDataframeDataset(dataset_def)
         dataset_impl.load()

         new_rows_df = pandas.DataFrame(rows)
-        new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
         dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
-
-        url = str(dataset_info.dataset_def.url.uri)
-        parsed_url = urlparse(url)
-
-        if parsed_url.scheme == "file" or not parsed_url.scheme:
-            file_path = parsed_url.path
-            os.makedirs(os.path.dirname(file_path), exist_ok=True)
-            dataset_impl.df.to_csv(file_path, index=False)
-        elif parsed_url.scheme == "data":
-            # For data URLs, we need to update the base64-encoded content
-            if not parsed_url.path.startswith("text/csv;base64,"):
-                raise ValueError("Data URL must be a base64-encoded CSV")
-
-            csv_buffer = dataset_impl.df.to_csv(index=False)
-            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
-            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
-        else:
-            raise ValueError(
-                f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
-            )
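In the localfs provider above, `iterrows` now rebuilds a `PandasDataframeDataset` on each call and pages with plain index arithmetic. The slice/next-index logic is easy to sanity-check in isolation; a standalone sketch mirroring it, with made-up sample rows:

```python
rows = [{"i": i} for i in range(5)]  # stand-in for the loaded dataframe rows


def paginate(start_index=None, limit=None):
    # mirrors the arithmetic in LocalFSDatasetIOImpl.iterrows
    start_index = start_index or 0
    if limit is None or limit == -1:
        end = len(rows)
    else:
        end = min(start_index + limit, len(rows))
    next_index = end if end < len(rows) else None
    return rows[start_index:end], next_index


print(paginate(limit=3))        # ([{'i': 0}, {'i': 1}, {'i': 2}], 3)
print(paginate(start_index=3))  # ([{'i': 3}, {'i': 4}], None)
```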

@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
+from urllib.parse import parse_qs, urlparse

 import datasets as hf_datasets

 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from .config import HuggingfaceDatasetIOConfig
@@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
 DATASETS_PREFIX = "datasets:"


-def load_hf_dataset(dataset_def: Dataset):
-    if dataset_def.metadata.get("path", None):
-        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
-    else:
-        df = get_dataframe_from_url(dataset_def.url)
-
-        if df is None:
-            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
-
-        dataset = hf_datasets.Dataset.from_pandas(df)
-
-        # drop columns not specified by schema
-        if dataset_def.dataset_schema:
-            dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
-
-    return dataset
+def parse_hf_params(dataset_def: Dataset):
+    uri = dataset_def.source.uri
+    parsed_uri = urlparse(uri)
+    params = parse_qs(parsed_uri.query)
+    params = {k: v[0] for k, v in params.items()}
+    path = parsed_uri.path.lstrip("/")
+
+    return path, params


 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=dataset_def.json(),
+            value=dataset_def.model_dump_json(),
         )
         self.dataset_infos[dataset_def.identifier] = dataset_def

@@ -80,7 +72,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         start_index = start_index or 0

@@ -98,7 +91,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)

         # Convert rows to HF Dataset format
         new_dataset = hf_datasets.Dataset.from_list(rows)
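For the Hugging Face provider, the dataset location now travels as a `huggingface://` URI whose query string carries `load_dataset` keyword arguments. A quick standalone check of what `parse_hf_params` extracts from the URI used in the new integration test:

```python
from urllib.parse import parse_qs, urlparse

uri = "huggingface://datasets/llamastack/simpleqa?split=train"

parsed = urlparse(uri)
params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
path = parsed.path.lstrip("/")

print(path)    # llamastack/simpleqa  (the leading "datasets" segment lands in parsed.netloc)
print(params)  # {'split': 'train'}

# These then feed hf_datasets.load_dataset(path, **params)
```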

@@ -10,18 +10,17 @@ from urllib.parse import unquote

 import pandas

-from llama_stack.apis.common.content_types import URL
 from llama_stack.providers.utils.memory.vector_store import parse_data_url


-def get_dataframe_from_url(url: URL):
+def get_dataframe_from_uri(uri: str):
     df = None
-    if url.uri.endswith(".csv"):
-        df = pandas.read_csv(url.uri)
-    elif url.uri.endswith(".xlsx"):
-        df = pandas.read_excel(url.uri)
-    elif url.uri.startswith("data:"):
-        parts = parse_data_url(url.uri)
+    if uri.endswith(".csv"):
+        df = pandas.read_csv(uri)
+    elif uri.endswith(".xlsx"):
+        df = pandas.read_excel(uri)
+    elif uri.startswith("data:"):
+        parts = parse_data_url(uri)
         data = parts["data"]
         if parts["is_base64"]:
             data = base64.b64decode(data)
@@ -39,6 +38,6 @@ def get_dataframe_from_url(url: URL):
         else:
             df = pandas.read_excel(data_bytes)
     else:
-        raise ValueError(f"Unsupported file type: {url}")
+        raise ValueError(f"Unsupported file type: {uri}")

     return df
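The renamed helper now takes a plain string, so it can be exercised directly with an inline `data:` URL of the same shape the tests build. A sketch; the CSV columns are illustrative, and the import path assumes the package layout shown in the diff:

```python
import base64

from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri

csv_bytes = b"input_query,generated_answer,expected_answer\nWhat is 2 + 2?,4,4\n"
data_url = "data:text/csv;base64," + base64.b64encode(csv_bytes).decode("utf-8")

df = get_dataframe_from_uri(data_url)  # plain .csv / .xlsx paths or URLs also work
print(df.head())
```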

@@ -1,114 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import base64
-import mimetypes
-import os
-from pathlib import Path
-
-import pytest
-
-# How to run this test:
-#
-# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio
-
-
-@pytest.fixture
-def dataset_for_test(llama_stack_client):
-    dataset_id = "test_dataset"
-    register_dataset(llama_stack_client, dataset_id=dataset_id)
-    yield
-    # Teardown - this always runs, even if the test fails
-    try:
-        llama_stack_client.datasets.unregister(dataset_id)
-    except Exception as e:
-        print(f"Warning: Failed to unregister test_dataset: {e}")
-
-
-def data_url_from_file(file_path: str) -> str:
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-
-    with open(file_path, "rb") as file:
-        file_content = file.read()
-
-    base64_content = base64.b64encode(file_content).decode("utf-8")
-    mime_type, _ = mimetypes.guess_type(file_path)
-
-    data_url = f"data:{mime_type};base64,{base64_content}"
-
-    return data_url
-
-
-def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
-    if for_rag:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
-    else:
-        test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
-    test_url = data_url_from_file(str(test_file))
-
-    if for_generation:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "chat_completion_input": {"type": "chat_completion_input"},
-        }
-    elif for_rag:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-            "context": {"type": "string"},
-        }
-    else:
-        dataset_schema = {
-            "expected_answer": {"type": "string"},
-            "input_query": {"type": "string"},
-            "generated_answer": {"type": "string"},
-        }
-
-    dataset_providers = [x for x in llama_stack_client.providers.list() if x.api == "datasetio"]
-    dataset_provider_id = dataset_providers[0].provider_id
-
-    llama_stack_client.datasets.register(
-        dataset_id=dataset_id,
-        dataset_schema=dataset_schema,
-        url=dict(uri=test_url),
-        provider_id=dataset_provider_id,
-    )
-
-
-def test_register_unregister_dataset(llama_stack_client):
-    register_dataset(llama_stack_client)
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 1
-    assert response[0].identifier == "test_dataset"
-
-    llama_stack_client.datasets.unregister("test_dataset")
-    response = llama_stack_client.datasets.list()
-    assert isinstance(response, list)
-    assert len(response) == 0
-
-
-def test_get_rows_paginated(llama_stack_client, dataset_for_test):
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=3,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 3
-    assert response.next_page_token == "3"
-
-    # iterate over all rows
-    response = llama_stack_client.datasetio.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=2,
-        page_token=response.next_page_token,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 2
-    assert response.next_page_token == "5"

tests/integration/datasets/test_datasets.py (new file, 95 lines)

@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import base64
+import mimetypes
+import os
+
+import pytest
+
+# How to run this test:
+#
+# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasets
+
+
+def data_url_from_file(file_path: str) -> str:
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, "rb") as file:
+        file_content = file.read()
+
+    base64_content = base64.b64encode(file_content).decode("utf-8")
+    mime_type, _ = mimetypes.guess_type(file_path)
+
+    data_url = f"data:{mime_type};base64,{base64_content}"
+
+    return data_url
+
+
+@pytest.mark.parametrize(
+    "purpose, source, provider_id, limit",
+    [
+        (
+            "eval/messages-answer",
+            {
+                "type": "uri",
+                "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+            },
+            "huggingface",
+            10,
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "rows",
+                "rows": [
+                    {
+                        "messages": [{"role": "user", "content": "Hello, world!"}],
+                        "answer": "Hello, world!",
+                    },
+                    {
+                        "messages": [
+                            {
+                                "role": "user",
+                                "content": "What is the capital of France?",
+                            }
+                        ],
+                        "answer": "Paris",
+                    },
+                ],
+            },
+            "localfs",
+            2,
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "uri",
+                "uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
+            },
+            "localfs",
+            5,
+        ),
+    ],
+)
+def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
+    dataset = llama_stack_client.datasets.register(
+        purpose=purpose,
+        source=source,
+    )
+    assert dataset.identifier is not None
+    assert dataset.provider_id == provider_id
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
+    assert len(iterrow_response.data) == limit
+
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier in [d.identifier for d in dataset_list]
+
+    llama_stack_client.datasets.unregister(dataset.identifier)
+    dataset_list = llama_stack_client.datasets.list()
+    assert dataset.identifier not in [d.identifier for d in dataset_list]