From 0ee82571a8b1e412b42d341839176bb4005843d1 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 23 Oct 2024 17:30:10 -0700 Subject: [PATCH] refactor --- .../impls/meta_reference/datasetio/datasetio.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index f38311a0c..77ae35538 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -61,18 +61,12 @@ class PandasDataframeDataset(BaseDataset): else: return self.df.iloc[idx].to_dict() - def validate_dataset_schema(self) -> None: - if self.df is None: - self.load() - - assert self.df is not None, "Dataset loading failed. Please check logs." - + def _validate_dataset_schema(self) -> None: + assert self.df is not None, "Dataset not loaded. Please call .load() first" # note that we will drop any columns in dataset that are not in the schema self.df = self.df[self.dataset_def.dataset_schema.keys()] - # check all columns in dataset schema are present assert len(self.df.columns) == len(self.dataset_def.dataset_schema) - # TODO: type checking against column types in dataset schema def load(self) -> None: @@ -106,6 +100,7 @@ class PandasDataframeDataset(BaseDataset): raise ValueError(f"Unsupported file type: {self.dataset_def.url}") self.df = df + self._validate_dataset_schema() class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): @@ -123,7 +118,6 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_def: DatasetDef, ) -> None: dataset_impl = PandasDataframeDataset(dataset_def) - dataset_impl.validate_dataset_schema() self.dataset_infos[dataset_def.identifier] = DatasetInfo( dataset_def=dataset_def, dataset_impl=dataset_impl,