From b946afddc0b387041a2962f71d36bb918c45854e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 10:28:51 -0800 Subject: [PATCH] datasetdef files: move the pre-registered MMLU DatasetDef into dataset_defs/llamastack_mmlu.py and rename its identifier from "mmlu" to "llamastack_mmlu" --- .../datasetio/dataset_defs/__init__.py | 5 +++++ .../datasetio/dataset_defs/llamastack_mmlu.py | 21 +++++++++++++++++++ .../huggingface/datasetio/huggingface.py | 19 +++-------------- 3 files changed, 29 insertions(+), 16 deletions(-) create mode 100644 llama_stack/providers/inline/huggingface/datasetio/dataset_defs/__init__.py create mode 100644 llama_stack/providers/inline/huggingface/datasetio/dataset_defs/llamastack_mmlu.py diff --git a/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/__init__.py b/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/llamastack_mmlu.py b/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/llamastack_mmlu.py new file mode 100644 index 000000000..396344144 --- /dev/null +++ b/llama_stack/providers/inline/huggingface/datasetio/dataset_defs/llamastack_mmlu.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from llama_models.llama3.api.datatypes import URL +from llama_stack.apis.common.type_system import StringType +from llama_stack.apis.datasetio import DatasetDef + + +llamastack_mmlu = DatasetDef( + identifier="llamastack_mmlu", + url=URL(uri="https://huggingface.co/datasets/yanxi0830/ls-mmlu"), + dataset_schema={ + "expected_answer": StringType(), + "input_query": StringType(), + "generated_answer": StringType(), + }, + metadata={"path": "yanxi0830/ls-mmlu", "split": "train"}, +) diff --git a/llama_stack/providers/inline/huggingface/datasetio/huggingface.py b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py index e17850841..3b8c3049c 100644 --- a/llama_stack/providers/inline/huggingface/datasetio/huggingface.py +++ b/llama_stack/providers/inline/huggingface/datasetio/huggingface.py @@ -5,16 +5,13 @@ # the root directory of this source tree. from typing import List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 - from llama_stack.apis.datasetio import * # noqa: F403 from datasets import load_dataset - -from llama_stack.apis.common.type_system import StringType from llama_stack.providers.datatypes import DatasetsProtocolPrivate from .config import HuggingfaceDatasetIOConfig +from .dataset_defs.llamastack_mmlu import llamastack_mmlu class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): @@ -25,18 +22,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def initialize(self) -> None: # pre-registered benchmark datasets - self.dataset_infos = { - "mmlu": DatasetDef( - identifier="mmlu", - url=URL(uri="https://huggingface.co/datasets/yanxi0830/ls-mmlu"), - dataset_schema={ - "expected_answer": StringType(), - "input_query": StringType(), - "generated_answer": StringType(), - }, - metadata={"path": "yanxi0830/ls-mmlu", "split": "train"}, - ) - } + self.pre_registered_datasets = [llamastack_mmlu] + self.dataset_infos = {x.identifier: x for x in self.pre_registered_datasets} async def 
shutdown(self) -> None: ...