diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index 366af8602..e8811d233 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -17,7 +17,7 @@ class PaginatedRowsResult(BaseModel): # the rows obey the DatasetSchema for the given dataset rows: List[Dict[str, Any]] total_count: int - next_page_token: Optional[int] = None + next_page_token: Optional[str] = None class DatasetStore(Protocol): @@ -34,6 +34,6 @@ class DatasetIO(Protocol): self, dataset_id: str, rows_in_page: int, - page_token: Optional[int] = None, + page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: ... diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 9160e1e13..e2b764d7f 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -50,11 +50,5 @@ class Datasets(Protocol): dataset_identifier: str, ) -> Optional[DatasetDefWithProvider]: ... - @webmethod(route="/datasets/delete") - async def delete_dataset( - self, - dataset_identifier: str, - ) -> None: ... - @webmethod(route="/datasets/list", method="GET") async def list_datasets(self) -> List[DatasetDefWithProvider]: ... diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 6656c34af..66f5d24e8 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -222,7 +222,3 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None: await self.register_object(dataset_def) - - async def delete_dataset(self, dataset_identifier: str) -> None: - # TODO: pass through for now - return diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index 57bfefc31..c3309ca18 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -10,32 +10,21 @@ import pandas from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 -from dataclasses import dataclass - from llama_stack.providers.datatypes import DatasetsProtocolPrivate -from llama_stack.providers.utils.datasetio.dataset_utils import BaseDataset +from llama_stack.providers.utils.datasetio.dataset_utils import BaseDataset, DatasetInfo from .config import MetaReferenceDatasetIOConfig -@dataclass -class DatasetInfo: - dataset_def: DatasetDef - dataset_impl: BaseDataset - next_page_token: Optional[int] = None - - -class CustomDataset(BaseDataset): +class PandasDataframeDataset(BaseDataset): def __init__(self, dataset_def: DatasetDef, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.dataset_def = dataset_def - # TODO: validate dataset_def against schema self.df = None - self.load() def __len__(self) -> int: if self.df is None: - self.load() + raise ValueError("Dataset not loaded. Please call .load() first") return len(self.df) def __getitem__(self, idx): @@ -91,10 +80,11 @@ class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate): self, dataset_def: DatasetDef, ) -> None: + dataset_impl = PandasDataframeDataset(dataset_def) + dataset_impl.load() self.dataset_infos[dataset_def.identifier] = DatasetInfo( dataset_def=dataset_def, - dataset_impl=CustomDataset(dataset_def), - next_page_token=0, + dataset_impl=dataset_impl, ) async def list_datasets(self) -> List[DatasetDef]: @@ -104,23 +94,24 @@ class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate): self, dataset_id: str, rows_in_page: int, - page_token: Optional[int] = None, + page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: dataset_info = self.dataset_infos.get(dataset_id) if page_token is None: - dataset_info.next_page_token = 0 + next_page_token = 0 + else: + next_page_token = int(page_token) if rows_in_page == -1: - rows = dataset_info.dataset_impl[dataset_info.next_page_token :] + rows = dataset_info.dataset_impl[next_page_token:] - start = dataset_info.next_page_token + start = next_page_token end = min(start + rows_in_page, len(dataset_info.dataset_impl)) rows = dataset_info.dataset_impl[start:end] - dataset_info.next_page_token = end return PaginatedRowsResult( rows=rows, total_count=len(rows), - next_page_token=dataset_info.next_page_token, + next_page_token=str(end), ) diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 89a0a38f6..27e80ff57 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -19,13 +19,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.impls.meta_reference.datasetio.MetaReferenceDatasetIOConfig", api_dependencies=[], ), - remote_provider_spec( - api=Api.datasetio, - adapter=AdapterSpec( - adapter_type="sample", - pip_packages=[], - module="llama_stack.providers.adapters.datasetio.sample", - config_class="llama_stack.providers.adapters.datasetio.sample.SampleConfig", - ), - ), ] diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 0d8b6e49b..85235a64b 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -95,7 +95,7 @@ async def test_get_rows_paginated(datasetio_settings): assert isinstance(response.rows, list) assert len(response.rows) == 3 - assert response.next_page_token == 3 + assert response.next_page_token == "3" # iterate over all rows response = await datasetio_impl.get_rows_paginated( @@ -106,4 +106,4 @@ async def test_get_rows_paginated(datasetio_settings): assert isinstance(response.rows, list) assert len(response.rows) == 10 - assert response.next_page_token == 13 + assert response.next_page_token == "13" diff --git a/llama_stack/providers/utils/datasetio/dataset_utils.py b/llama_stack/providers/utils/datasetio/dataset_utils.py index 960004937..fd56d2424 100644 --- a/llama_stack/providers/utils/datasetio/dataset_utils.py +++ b/llama_stack/providers/utils/datasetio/dataset_utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from abc import ABC, abstractmethod +from dataclasses import dataclass +from llama_stack.apis.datasetio import * # noqa: F403 class BaseDataset(ABC): @@ -21,3 +23,9 @@ class BaseDataset(ABC): @abstractmethod def load(self): raise NotImplementedError() + + +@dataclass +class DatasetInfo: + dataset_def: DatasetDef + dataset_impl: BaseDataset