fix iterrows

This commit is contained in:
Xi Yan 2025-03-15 14:24:46 -07:00
parent 82ec0d24f3
commit a197101635
2 changed files with 10 additions and 19 deletions

View file

@ -131,33 +131,24 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)
start = next_page_token
if rows_in_page == -1:
if limit is None or limit == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
end = min(start_index + limit, len(dataset_info.dataset_impl))
rows = dataset_info.dataset_impl[start:end]
rows = dataset_info.dataset_impl[start_index:end]
return IterrowsResponse(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
data=rows,
next_index=end if end < len(dataset_info.dataset_impl) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:

View file

@ -84,7 +84,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
start_index = start_index or 0
if limit == -1:
if limit is None or limit == -1:
end = len(loaded_dataset)
else:
end = min(start_index + limit, len(loaded_dataset))
@ -93,7 +93,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return IterrowsResponse(
data=rows,
next_index=end,
next_index=end if end < len(loaded_dataset) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: