diff --git a/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/__init__.py b/llama_stack/providers/inline/huggingface/datasetio/benchmarks/__init__.py similarity index 77% rename from llama_stack/providers/inline/huggingface/datasetio/dataset_defs/__init__.py rename to llama_stack/providers/inline/huggingface/datasetio/benchmarks/__init__.py index 756f351d8..c608f6fff 100644 --- a/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/__init__.py +++ b/llama_stack/providers/inline/huggingface/datasetio/benchmarks/__init__.py @@ -3,3 +3,4 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from .llamastack_mmlu import llamastack_mmlu # noqa: F401 diff --git a/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/llamastack_mmlu.py b/llama_stack/providers/inline/huggingface/datasetio/benchmarks/llamastack_mmlu.py similarity index 100% rename from llama_stack/providers/inline/huggingface/datasetio/dataset_defs/llamastack_mmlu.py rename to llama_stack/providers/inline/huggingface/datasetio/benchmarks/llamastack_mmlu.py diff --git a/llama_stack/providers/inline/huggingface/datasetio/huggingface.py b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py index 8958f6161..fadd54209 100644 --- a/llama_stack/providers/inline/huggingface/datasetio/huggingface.py +++ b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py @@ -8,24 +8,25 @@ from typing import List, Optional from llama_stack.apis.datasetio import * # noqa: F403 -from datasets import Dataset, load_dataset +import datasets as hf_datasets from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url +from .benchmarks import llamastack_mmlu + from .config import HuggingfaceDatasetIOConfig -from .dataset_defs.llamastack_mmlu import llamastack_mmlu def load_hf_dataset(dataset_def: DatasetDef): if dataset_def.metadata.get("path", None): - return load_dataset(**dataset_def.metadata) + return hf_datasets.load_dataset(**dataset_def.metadata) df = get_dataframe_from_url(dataset_def.url) if df is None: raise ValueError(f"Failed to load dataset from {dataset_def.url}") - dataset = Dataset.from_pandas(df) + dataset = hf_datasets.Dataset.from_pandas(df) return dataset @@ -37,8 +38,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def initialize(self) -> None: # pre-registered benchmark datasets - self.pre_registered_datasets = [llamastack_mmlu] - self.dataset_infos = {x.identifier: x for x in self.pre_registered_datasets} + pre_registered_datasets = [llamastack_mmlu] + self.dataset_infos = {x.identifier: x for x in pre_registered_datasets} async def shutdown(self) -> None: ...