llama-stack/llama_stack/apis/datasetio/datasetio.py
Sébastien Han 2ffa2b77ed
refactor: extract pagination logic into shared helper function (#1770)
# What does this PR do?

Move pagination logic from LocalFS and HuggingFace implementations into
a common helper function to ensure consistent pagination behavior across
providers. This reduces code duplication and centralizes pagination
logic in one place.


## Test Plan

Run this script:

```
from llama_stack_client import LlamaStackClient

# Initialize the client
client = LlamaStackClient(base_url="http://localhost:8321")

# Register a dataset
response = client.datasets.register(
    purpose="eval/messages-answer",  # or "eval/question-answer" or "post-training/messages"
    source={"type": "uri", "uri": "huggingface://datasets/llamastack/simpleqa?split=train"},
    dataset_id="my_dataset",  # optional, will be auto-generated if not provided
    metadata={"description": "My evaluation dataset"},  # optional
)

# Verify the dataset was registered by listing all datasets
datasets = client.datasets.list()
print(f"Registered datasets: {[d.identifier for d in datasets]}")

# You can then access the data using the datasetio API
# rows = client.datasets.iterrows(dataset_id="my_dataset", start_index=1, limit=2)
rows = client.datasets.iterrows(dataset_id="my_dataset")
print(f"Data: {rows.data}")
```

Then vary `start_index` and `limit` to exercise the pagination behavior.

[//]: # (## Documentation)

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-03-31 13:08:29 -07:00

47 lines
1.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.schema_utils import webmethod
class DatasetStore(Protocol):
    """Structural interface for an object that can look up a dataset by ID.

    As a ``typing.Protocol``, any object with a compatible ``get_dataset``
    method satisfies this interface for static type checking; explicit
    subclassing is not required.
    """

    def get_dataset(
        self,
        dataset_id: str,
    ) -> Dataset:
        """Return the ``Dataset`` identified by ``dataset_id``."""
@runtime_checkable
class DatasetIO(Protocol):
    """API protocol for iterating over rows of, and appending rows to, datasets."""

    # Kept to align with the inference/safety APIs; currently unused here.
    dataset_store: DatasetStore

    @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
    async def iterrows(
        self,
        dataset_id: str,
        start_index: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> PaginatedResponse:
        """Get a paginated list of rows from a dataset.

        Uses offset-based pagination where:
        - start_index: The starting index (0-based). If None, starts from the beginning.
        - limit: Number of items to return. If None or -1, returns all remaining
          items starting at start_index.

        The response includes:
        - data: List of items for the current page
        - has_more: Whether there are more items available after this set

        :param dataset_id: The ID of the dataset to get the rows from.
        :param start_index: 0-based index into the dataset of the first row to get. Starts from the first row if None.
        :param limit: The number of rows to get. If None or -1, all rows from start_index onward are returned.
        :returns: A PaginatedResponse with the requested rows and a has_more flag.
        """
        ...

    @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        """Append rows to a dataset.

        :param dataset_id: The ID of the dataset to append the rows to.
        :param rows: The rows to append to the dataset.
        """
        ...