From 8b45d147df4519533e0fe4f8b38d2e03c7c4dbd8 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 12 Dec 2024 10:23:09 -0800 Subject: [PATCH] [/datasetio] drop columns not specified by dataset schema for huggingface provider (#611) # What does this PR do? **Why** - huggingface datasets could have extra unused columns, some of these columns (e.g. images) is unable to be casted as JSON over http requests for datasetio. - it is also inefficient to create a new dataset that's a subset of columns **Solution** - drop columns not specified by dataset schema ## Test Plan Tested with script: https://gist.github.com/yanxi0830/23be5725e0d82d79e24cc5dd1d21b571 ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- .../remote/datasetio/huggingface/huggingface.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index db52270a7..2fde7c3d0 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -21,14 +21,19 @@ DATASETS_PREFIX = "datasets:" def load_hf_dataset(dataset_def: Dataset): if dataset_def.metadata.get("path", None): - return hf_datasets.load_dataset(**dataset_def.metadata) + dataset = hf_datasets.load_dataset(**dataset_def.metadata) + else: + df = get_dataframe_from_url(dataset_def.url) - df = get_dataframe_from_url(dataset_def.url) + if df is None: + raise ValueError(f"Failed to load dataset from {dataset_def.url}") - if df is None: - raise ValueError(f"Failed to load dataset from {dataset_def.url}") + dataset = hf_datasets.Dataset.from_pandas(df) + + # drop columns not specified by schema + if dataset_def.dataset_schema: + dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys())) - dataset = hf_datasets.Dataset.from_pandas(df) return dataset