feat(api): (1/n) datasets api clean up (#1573)

## PR Stack
- https://github.com/meta-llama/llama-stack/pull/1573
- https://github.com/meta-llama/llama-stack/pull/1625
- https://github.com/meta-llama/llama-stack/pull/1656
- https://github.com/meta-llama/llama-stack/pull/1657
- https://github.com/meta-llama/llama-stack/pull/1658
- https://github.com/meta-llama/llama-stack/pull/1659
- https://github.com/meta-llama/llama-stack/pull/1660

**Client SDK**
- https://github.com/meta-llama/llama-stack-client-python/pull/203

**CI**
- 1391130488
<img width="1042" alt="image"
src="https://github.com/user-attachments/assets/69636067-376d-436b-9204-896e2dd490ca"
/>
- Note: the `test_rag_agent_with_attachments` test is flaky and unrelated to this PR

## Doc
<img width="789" alt="image"
src="https://github.com/user-attachments/assets/b88390f3-73d6-4483-b09a-a192064e32d9"
/>


## Client Usage
```python
client.datasets.register(
    source={
        "type": "uri",
        "uri": "lsfs://mydata.jsonl",
    },
    schema="jsonl_messages",
    # optional
    dataset_id="my_first_train_data",
)

# quick prototype debugging with inline rows
client.datasets.register(
    source={
        "type": "rows",
        "rows": [
            {"messages": [...]},
        ],
    },
    schema="jsonl_messages",
)
```
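
For reading rows back, this PR replaces `get_rows_paginated` with `iterrows` (see the diffs below). A minimal read-back sketch, assuming the Python client from the linked SDK PR mirrors the server-side `iterrows(dataset_id, start_index, limit)` signature:

```python
# Sketch only: assumes client.datasets.iterrows mirrors the server-side API
# introduced in this PR (start_index/limit in, data/next_start_index out).
page = client.datasets.iterrows(
    dataset_id="my_first_train_data",
    start_index=0,  # optional; defaults to the beginning of the dataset
    limit=10,       # optional; -1 or None returns all remaining rows
)
for row in page.data:
    print(row["messages"])
# page.next_start_index is None when there are no more rows; otherwise pass it
# back as start_index to fetch the next page.
```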

## Test Plan
- CI: 1387805545

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/scoring/test_scoring.py
```

```
pytest -v -s --nbval-lax ./docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb
```
Commit 5287b437ae (parent 3b35a39b8b)
Xi Yan, 2025-03-17 16:55:45 -07:00, committed by GitHub
29 changed files with 2593 additions and 2296 deletions


@@ -3,20 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
@@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:"
class BaseDataset(ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError()
@abstractmethod
def __getitem__(self, idx):
raise NotImplementedError()
@abstractmethod
def load(self):
raise NotImplementedError()
@dataclass
class DatasetInfo:
dataset_def: Dataset
dataset_impl: BaseDataset
class PandasDataframeDataset(BaseDataset):
class PandasDataframeDataset:
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dataset_def = dataset_def
@@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
else:
return self.df.iloc[idx].to_dict()
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df
def load(self) -> None:
if self.df is not None:
return
df = get_dataframe_from_url(self.dataset_def.url)
if df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
self.df = self._validate_dataset_schema(df)
if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -99,95 +66,55 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
self.dataset_infos[dataset.identifier] = dataset
async def shutdown(self) -> None: ...
async def register_dataset(
self,
dataset: Dataset,
dataset_def: Dataset,
) -> None:
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset.identifier}"
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset.json(),
)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0
if page_token is None or len(page_token) == 0:
next_page_token = 0
if limit is None or limit == -1:
end = len(dataset_impl)
else:
next_page_token = int(page_token)
end = min(start_index + limit, len(dataset_impl))
start = next_page_token
if rows_in_page == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
rows = dataset_impl[start_index:end]
rows = dataset_info.dataset_impl[start:end]
return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
return IterrowsResponse(
data=rows,
next_start_index=end if end < len(dataset_impl) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id)
if dataset_info is None:
raise ValueError(f"Dataset with id {dataset_id} not found")
dataset_impl = dataset_info.dataset_impl
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)
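
The `next_start_index` field above acts as a cursor: it is `None` once the dataset is exhausted, otherwise it is the `start_index` to pass on the next call. A minimal consumer-side pagination sketch against that contract (illustrative only; `datasetio` stands in for any `DatasetIO` implementation):

```python
# Illustrative pagination loop over the new DatasetIO.iterrows contract.
async def read_all_rows(datasetio, dataset_id: str, page_size: int = 100) -> list:
    rows = []
    start_index = 0
    while True:
        page = await datasetio.iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=page_size,
        )
        rows.extend(page.data)
        if page.next_start_index is None:  # dataset exhausted
            break
        start_index = page.next_start_index
    return rows
```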


@@ -14,16 +14,11 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
get_valid_schemas,
validate_dataset_schema,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
@@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
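
The caller-side migration, condensed from this hunk (the scoring and post-training providers below follow the same pattern): `get_rows_paginated(rows_in_page=...)` becomes `iterrows(limit=...)`, and the result's `.rows` becomes `.data`.

```python
# Before (old DatasetIO API):
all_rows = await self.datasetio_api.get_rows_paginated(
    dataset_id=dataset_id,
    rows_in_page=-1,  # -1 == fetch all rows
)
input_rows = all_rows.rows

# After (this PR):
all_rows = await self.datasetio_api.iterrows(
    dataset_id=dataset_id,
    limit=-1,  # -1 == fetch all rows
)
input_rows = all_rows.data
```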


@@ -328,13 +328,13 @@ class LoraFinetuningSingleDevice:
batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated(
return await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows
rows = all_rows.data
await validate_input_dataset_schema(
datasets_api=self.datasets_api,


@@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
@@ -82,12 +84,12 @@ class BasicScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:


@@ -167,11 +167,11 @@ class BraintrustScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()


@@ -72,12 +72,12 @@ class LlmAsJudgeScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset: