refactor: extract pagination logic into shared helper function

Move pagination logic from the LocalFS and HuggingFace implementations into
a common helper function so that pagination behaves consistently across
providers. This removes duplicated code and centralizes the pagination
logic.
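
All providers now share one call shape; a minimal sketch of the intended
usage (the names match the paginate_records helper and PaginatedResponse
model introduced in this commit):

    # Each provider materializes its rows as a list of dicts, then delegates:
    records = [{"id": 0}, {"id": 1}, {"id": 2}]
    response = paginate_records(records, start_index=0, limit=2)
    # response.data == [{"id": 0}, {"id": 1}]; response.has_more is True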

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-03-24 17:49:19 +01:00
parent 9b478f3756
commit 8e15e3c1b8
9 changed files with 130 additions and 73 deletions

@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+
+from pydantic import BaseModel
+
+from llama_stack.schema_utils import json_schema_type
+
+
+@json_schema_type
+class PaginatedResponse(BaseModel):
+    """A generic paginated response that follows a simple format.
+
+    :param data: The list of items for the current page
+    :param has_more: Whether there are more items available after this set
+    """
+
+    data: List[Dict[str, Any]]
+    has_more: bool
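
Since PaginatedResponse is a plain Pydantic model, standard construction and
serialization apply; a quick sketch (assuming Pydantic v2's model_dump):

    from llama_stack.apis.common.responses import PaginatedResponse

    page = PaginatedResponse(data=[{"name": "row-0"}], has_more=True)
    assert page.model_dump() == {"data": [{"name": "row-0"}], "has_more": True}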

@@ -6,23 +6,9 @@
 from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
 
-from pydantic import BaseModel
-
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-
-@json_schema_type
-class IterrowsResponse(BaseModel):
-    """
-    A paginated list of rows from a dataset.
-
-    :param data: The rows in the current page.
-    :param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
-    """
-
-    data: List[Dict[str, Any]]
-    next_start_index: Optional[int] = None
+from llama_stack.schema_utils import webmethod
 
 
 class DatasetStore(Protocol):
@@ -34,15 +20,22 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore
 
     # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
     @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
     async def iterrows(
         self,
         dataset_id: str,
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
-    ) -> IterrowsResponse:
-        """Get a paginated list of rows from a dataset. Uses cursor-based pagination.
+    ) -> PaginatedResponse:
+        """Get a paginated list of rows from a dataset.
+
+        Uses offset-based pagination where:
+        - start_index: The starting index (0-based). If None, starts from beginning.
+        - limit: Number of items to return. If None or -1, returns all items.
+
+        The response includes:
+        - data: List of items for the current page
+        - has_more: Whether there are more items available after this set
 
         :param dataset_id: The ID of the dataset to get the rows from.
         :param start_index: Index into dataset for the first row to get. Get all rows if None.
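
The docstring above fully specifies the paging contract, so a caller can walk
a dataset with a simple offset loop. A hedged sketch (fetch_all_rows is a
hypothetical helper, not part of this diff; any DatasetIO implementation works):

    # Hypothetical helper: pages through a dataset until has_more is False.
    async def fetch_all_rows(datasetio, dataset_id: str, page_size: int = 100):
        rows = []
        start_index = 0
        while True:
            page = await datasetio.iterrows(dataset_id, start_index=start_index, limit=page_size)
            rows.extend(page.data)
            if not page.has_more:
                return rows
            start_index += len(page.data)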

@@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
+from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
@@ -497,7 +498,7 @@ class DatasetIORouter(DatasetIO):
         dataset_id: str,
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
-    ) -> IterrowsResponse:
+    ) -> PaginatedResponse:
         logger.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
         )

@@ -7,9 +7,11 @@ from typing import Any, Dict, List, Optional
 
 import pandas
 
-from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
+from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.utils.datasetio.pagination import paginate_records
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -92,24 +94,13 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         dataset_id: str,
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
-    ) -> IterrowsResponse:
+    ) -> PaginatedResponse:
         dataset_def = self.dataset_infos[dataset_id]
         dataset_impl = PandasDataframeDataset(dataset_def)
         await dataset_impl.load()
 
-        start_index = start_index or 0
-
-        if limit is None or limit == -1:
-            end = len(dataset_impl)
-        else:
-            end = min(start_index + limit, len(dataset_impl))
-
-        rows = dataset_impl[start_index:end]
-
-        return IterrowsResponse(
-            data=rows,
-            next_start_index=end if end < len(dataset_impl) else None,
-        )
+        records = dataset_impl.df.to_dict("records")
+        return paginate_records(records, start_index, limit)
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
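
The df.to_dict("records") call above converts the provider's DataFrame into
the list-of-dicts shape that paginate_records expects; a standalone pandas
illustration (not tied to this provider):

    import pandas as pd

    df = pd.DataFrame({"question": ["q1", "q2"], "answer": ["a1", "a2"]})
    records = df.to_dict("records")
    # records == [{"question": "q1", "answer": "a1"},
    #             {"question": "q2", "answer": "a2"}]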

@@ -8,9 +8,11 @@ from urllib.parse import parse_qs, urlparse
 
 import datasets as hf_datasets
 
-from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
+from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.utils.datasetio.pagination import paginate_records
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from .config import HuggingfaceDatasetIOConfig
@@ -70,24 +72,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         dataset_id: str,
         start_index: Optional[int] = None,
         limit: Optional[int] = None,
-    ) -> IterrowsResponse:
+    ) -> PaginatedResponse:
         dataset_def = self.dataset_infos[dataset_id]
         path, params = parse_hf_params(dataset_def)
         loaded_dataset = hf_datasets.load_dataset(path, **params)
 
-        start_index = start_index or 0
-
-        if limit is None or limit == -1:
-            end = len(loaded_dataset)
-        else:
-            end = min(start_index + limit, len(loaded_dataset))
-
-        rows = [loaded_dataset[i] for i in range(start_index, end)]
-
-        return IterrowsResponse(
-            data=rows,
-            next_start_index=end if end < len(loaded_dataset) else None,
-        )
+        records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
+        return paginate_records(records, start_index, limit)
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]

@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List
+
+from llama_stack.apis.common.responses import PaginatedResponse
+
+
+def paginate_records(
+    records: List[Dict[str, Any]],
+    start_index: int | None = None,
+    limit: int | None = None,
+) -> PaginatedResponse:
+    """Helper function to handle pagination of records consistently across implementations.
+
+    Inspired by Stripe's pagination: https://docs.stripe.com/api/pagination
+
+    :param records: List of records to paginate
+    :param start_index: The starting index (0-based). If None, starts from beginning.
+    :param limit: Number of items to return. If None or -1, returns all items.
+    :return: PaginatedResponse with the paginated data
+    """
+    # Handle special case for fetching all rows
+    if limit is None or limit == -1:
+        return PaginatedResponse(
+            data=records,
+            has_more=False,
+        )
+
+    # Use offset-based pagination
+    start_index = start_index or 0
+    end_index = min(start_index + limit, len(records))
+    page_data = records[start_index:end_index]
+
+    # Calculate if there are more records
+    has_more = end_index < len(records)
+
+    return PaginatedResponse(
+        data=page_data,
+        has_more=has_more,
+    )
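
A few worked calls to illustrate the edge cases handled above (results follow
directly from the code):

    records = [{"i": 0}, {"i": 1}, {"i": 2}]

    paginate_records(records)                          # all 3 rows, has_more=False
    paginate_records(records, limit=-1)                # all 3 rows, has_more=False
    paginate_records(records, start_index=0, limit=2)  # rows 0-1, has_more=True
    paginate_records(records, start_index=2, limit=2)  # row 2 only, has_more=False
    paginate_records(records, start_index=5, limit=2)  # empty page, has_more=False

A start_index past the end yields an empty page rather than an error, which
matches the offset semantics documented in the DatasetIO API.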