def load_hf_dataset(dataset_def: Dataset):
    """Load the dataset described by ``dataset_def`` as a Hugging Face dataset.

    Two sources are supported:

    * If ``dataset_def.metadata`` has a truthy ``"path"`` entry, the whole
      metadata mapping is forwarded as keyword arguments to
      ``hf_datasets.load_dataset``.
    * Otherwise a dataframe is fetched with ``get_dataframe_from_url`` from
      ``dataset_def.url`` and wrapped with ``hf_datasets.Dataset.from_pandas``.

    When ``dataset_def.dataset_schema`` is non-empty, only the columns named
    by its keys are kept in the result.

    Args:
        dataset_def: Dataset definition; this function reads its
            ``metadata``, ``url`` and ``dataset_schema`` attributes.

    Returns:
        The loaded dataset object, restricted to the schema's columns when a
        schema is given. The concrete type is whatever the underlying loader
        returns.

    Raises:
        ValueError: If ``get_dataframe_from_url`` returns ``None`` for the URL.
    """
    hub_kwargs = dataset_def.metadata

    if not hub_kwargs.get("path", None):
        # No hub path configured: fall back to fetching the data from the URL.
        frame = get_dataframe_from_url(dataset_def.url)
        if frame is None:
            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
        loaded = hf_datasets.Dataset.from_pandas(frame)
    else:
        loaded = hf_datasets.load_dataset(**hub_kwargs)

    schema = dataset_def.dataset_schema
    if not schema:
        # Nothing to filter by; hand back every column as loaded.
        return loaded

    # Drop columns not specified by the schema.
    return loaded.select_columns([*schema.keys()])