diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index 404e5c887..e3dbbc461 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -48,8 +48,7 @@ class PandasDataframeDataset(BaseDataset): self.df = None def __len__(self) -> int: - if self.df is None: - raise ValueError("Dataset not loaded. Please call .load() first") + assert self.df is not None, "Dataset not loaded. Please call .load() first" return len(self.df) def __getitem__(self, idx): @@ -59,7 +58,7 @@ class PandasDataframeDataset(BaseDataset): return self.df.iloc[idx].to_dict() def load(self) -> None: - if self.df: + if self.df is not None: return # TODO: more robust support w/ data url @@ -106,7 +105,6 @@ class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate): dataset_def: DatasetDef, ) -> None: dataset_impl = PandasDataframeDataset(dataset_def) - dataset_impl.load() self.dataset_infos[dataset_def.identifier] = DatasetInfo( dataset_def=dataset_def, dataset_impl=dataset_impl, @@ -123,6 +121,8 @@ class MetaReferenceDatasetioImpl(DatasetIO, DatasetsProtocolPrivate): filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: dataset_info = self.dataset_infos.get(dataset_id) + dataset_info.dataset_impl.load() + if page_token is None: next_page_token = 0 else: