forked from phoenix-oss/llama-stack-mirror
feat(dataset api): (1.4/n) fix resolver signature mismatch (#1658)
# What does this PR do? - fix datasets api signature mis-match so that llama stack run can start [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` llama stack run ``` <img width="626" alt="image" src="https://github.com/user-attachments/assets/59072d1a-ccb6-453a-80e8-d87419896c41" /> [//]: # (## Documentation)
This commit is contained in:
parent
72ccdc19a8
commit
2c9d624910
10 changed files with 105 additions and 80 deletions
|
@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import datasets as hf_datasets
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
@ -73,36 +73,27 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
start_index = start_index or 0
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
if limit is None or limit == -1:
|
||||
end = len(loaded_dataset)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(loaded_dataset))
|
||||
end = min(start_index + limit, len(loaded_dataset))
|
||||
|
||||
rows = [loaded_dataset[i] for i in range(start, end)]
|
||||
rows = [loaded_dataset[i] for i in range(start_index, end)]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_index=end if end < len(loaded_dataset) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue