forked from phoenix-oss/llama-stack-mirror
feat(dataset api): (1.4/n) fix resolver signature mismatch (#1658)
# What does this PR do? - fix datasets api signature mis-match so that llama stack run can start [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` llama stack run ``` <img width="626" alt="image" src="https://github.com/user-attachments/assets/59072d1a-ccb6-453a-80e8-d87419896c41" /> [//]: # (## Documentation)
This commit is contained in:
parent
72ccdc19a8
commit
2c9d624910
10 changed files with 105 additions and 80 deletions
|
@ -13,7 +13,7 @@ from urllib.parse import urlparse
|
|||
import pandas
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
@ -128,36 +128,27 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
dataset_info.dataset_impl.load()
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
start_index = start_index or 0
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
if limit is None or limit == -1:
|
||||
end = len(dataset_info.dataset_impl)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
|
||||
end = min(start_index + limit, len(dataset_info.dataset_impl))
|
||||
|
||||
rows = dataset_info.dataset_impl[start:end]
|
||||
rows = dataset_info.dataset_impl[start_index:end]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_index=end if end < len(dataset_info.dataset_impl) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue