From 99ed1425fc4db16973fc6224e22caeeb9f2b19dc Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 10 Oct 2024 17:19:18 -0700
Subject: [PATCH] add dataset datatypes

---
 llama_stack/apis/dataset/dataset.py           | 97 +++++++++++++++----
 llama_stack/apis/evals/evals.py               |  1 +
 .../registry/datasets/__init__.py             | 28 +++---
 .../distribution/registry/datasets/dataset.py | 96 ++++++++++-------
 .../registry/datasets/dataset_registry.py     |  2 +-
 5 files changed, 157 insertions(+), 67 deletions(-)

diff --git a/llama_stack/apis/dataset/dataset.py b/llama_stack/apis/dataset/dataset.py
index 8ab135b6a..164e16be4 100644
--- a/llama_stack/apis/dataset/dataset.py
+++ b/llama_stack/apis/dataset/dataset.py
@@ -4,46 +4,105 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, Optional, Protocol
-
-from llama_models.llama3.api.datatypes import URL
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, Dict, Generic, Iterator, Literal, Protocol, TypeVar, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+TDatasetRow = TypeVar("TDatasetRow")
 
 
 @json_schema_type
-class TrainEvalDataset(BaseModel):
-    """Dataset to be used for training or evaluating language models."""
-
-    # unique identifier associated with the dataset
-    dataset_id: str
-    content_url: URL
-    metadata: Optional[Dict[str, Any]] = None
+class DatasetRow(BaseModel): ...
 
 
 @json_schema_type
-class CreateDatasetRequest(BaseModel):
-    """Request to create a dataset."""
+class DictSample(DatasetRow):
+    data: Dict[str, Any]
 
-    uuid: str
-    dataset: TrainEvalDataset
+
+@json_schema_type
+class Generation(BaseModel): ...
+
+
+@json_schema_type
+class DatasetType(Enum):
+    custom = "custom"
+    huggingface = "huggingface"
+
+
+@json_schema_type
+class HuggingfaceDatasetDef(BaseModel):
+    type: Literal[DatasetType.huggingface.value] = DatasetType.huggingface.value
+    identifier: str = Field(
+        description="A unique name for the dataset",
+    )
+    dataset_name: str = Field(
+        description="The name of the dataset on HF (e.g. hellaswag)",
+    )
+    kwargs: Dict[str, Any] = Field(
+        description="Any additional arguments to pass to Huggingface's load_dataset (e.g. split, trust_remote_code)",
+        default_factory=dict,
+    )
+
+
+@json_schema_type
+class CustomDatasetDef(BaseModel):
+    type: Literal[DatasetType.custom.value] = DatasetType.custom.value
+    identifier: str = Field(
+        description="A unique name for the dataset",
+    )
+    url: str = Field(
+        description="The URL to the dataset",
+    )
+
+
+DatasetDef = Annotated[
+    Union[
+        HuggingfaceDatasetDef,
+        CustomDatasetDef,
+    ],
+    Field(discriminator="type"),
+]
+
+
+class BaseDataset(ABC, Generic[TDatasetRow]):
+    def __init__(self) -> None:
+        self.type: str = self.__class__.__name__
+
+    @abstractmethod
+    def __iter__(self) -> Iterator[TDatasetRow]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def load(self) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def __str__(self) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def __len__(self) -> int:
+        raise NotImplementedError()
 
 
 class Datasets(Protocol):
     @webmethod(route="/datasets/create")
    def create_dataset(
         self,
-        uuid: str,
-        dataset: TrainEvalDataset,
+        dataset: DatasetDef,
     ) -> None: ...
 
     @webmethod(route="/datasets/get")
     def get_dataset(
         self,
-        dataset_uuid: str,
-    ) -> TrainEvalDataset: ...
+        dataset_identifier: str,
+    ) -> DatasetDef: ...
 
     @webmethod(route="/datasets/delete")
     def delete_dataset(
diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py
index dbb1348a5..629e68d32 100644
--- a/llama_stack/apis/evals/evals.py
+++ b/llama_stack/apis/evals/evals.py
@@ -33,6 +33,7 @@ class EvaluateTaskConfig(BaseModel):
 class EvaluateResponse(BaseModel):
     """Scores for evaluation."""
 
+    preprocess_output: GenerationOutput
     metrics: Dict[str, str]
 
 
diff --git a/llama_stack/distribution/registry/datasets/__init__.py b/llama_stack/distribution/registry/datasets/__init__.py
index 0b7a84395..3a60d6a5e 100644
--- a/llama_stack/distribution/registry/datasets/__init__.py
+++ b/llama_stack/distribution/registry/datasets/__init__.py
@@ -5,19 +5,19 @@
 # the root directory of this source tree.
 
 # TODO: make these import config based
-from .dataset import CustomDataset, HFDataset
-from .dataset_registry import DatasetRegistry
+# from .dataset import CustomDataset, HFDataset
+# from .dataset_registry import DatasetRegistry
 
-DATASETS_REGISTRY = {
-    "mmlu-simple-eval-en": CustomDataset(
-        name="mmlu_eval",
-        url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-    ),
-    "hellaswag": HFDataset(
-        name="hellaswag",
-        url="hf://hellaswag?split=validation&trust_remote_code=True",
-    ),
-}
+# DATASETS_REGISTRY = {
+#     "mmlu-simple-eval-en": CustomDataset(
+#         name="mmlu_eval",
+#         url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+#     ),
+#     "hellaswag": HFDataset(
+#         name="hellaswag",
+#         url="hf://hellaswag?split=validation&trust_remote_code=True",
+#     ),
+# }
 
-for k, v in DATASETS_REGISTRY.items():
-    DatasetRegistry.register(k, v)
+# for k, v in DATASETS_REGISTRY.items():
+#     DatasetRegistry.register(k, v)
diff --git a/llama_stack/distribution/registry/datasets/dataset.py b/llama_stack/distribution/registry/datasets/dataset.py
index 1a16a5c51..e3a2de399 100644
--- a/llama_stack/distribution/registry/datasets/dataset.py
+++ b/llama_stack/distribution/registry/datasets/dataset.py
@@ -3,60 +3,90 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-from abc import ABC, abstractmethod
-from urllib.parse import parse_qs, urlparse
-
 import pandas
 from datasets import Dataset, load_dataset
 
+from llama_stack.apis.dataset import *  # noqa: F403
+
 
-class BaseDataset(ABC):
-    def __init__(self, name: str):
+class CustomDataset(BaseDataset[DictSample]):
+    def __init__(self, config: CustomDatasetDef) -> None:
+        super().__init__()
+        self.config = config
         self.dataset = None
-        self.dataset_id = name
-        self.type = self.__class__.__name__
+        self.index = 0
 
-    def __iter__(self):
-        return iter(self.dataset)
+    def __iter__(self) -> Iterator[DictSample]:
+        return self
 
-    @abstractmethod
-    def load(self):
-        pass
+    def __next__(self) -> DictSample:
+        if not self.dataset:
+            self.load()
+        if self.index >= len(self.dataset):
+            raise StopIteration
+        sample = DictSample(data=self.dataset[self.index])
+        self.index += 1
+        return sample
 
+    def __str__(self):
+        return f"CustomDataset({self.config})"
 
-class CustomDataset(BaseDataset):
-    def __init__(self, name, url):
-        super().__init__(name)
-        self.url = url
+    def __len__(self):
+        if not self.dataset:
+            self.load()
+        return len(self.dataset)
 
     def load(self):
         if self.dataset:
             return
         # TODO: better support w/ data url
-        if self.url.endswith(".csv"):
-            df = pandas.read_csv(self.url)
-        elif self.url.endswith(".xlsx"):
-            df = pandas.read_excel(self.url)
+        if self.config.url.endswith(".csv"):
+            df = pandas.read_csv(self.config.url)
+        elif self.config.url.endswith(".xlsx"):
+            df = pandas.read_excel(self.config.url)
+        else:
+            raise ValueError(f"Unsupported dataset format: {self.config.url}")
         self.dataset = Dataset.from_pandas(df)
 
 
-class HFDataset(BaseDataset):
-    def __init__(self, name, url):
-        super().__init__(name)
-        self.url = url
+class HuggingfaceDataset(BaseDataset[DictSample]):
+    def __init__(self, config: HuggingfaceDatasetDef):
+        super().__init__()
+        self.config = config
+        self.dataset = None
+        self.index = 0
+
+    def __iter__(self) -> Iterator[DictSample]:
+        return self
+
+    def __next__(self) -> DictSample:
+        if not self.dataset:
+            self.load()
+        if self.index >= len(self.dataset):
+            raise StopIteration
+        sample = DictSample(data=self.dataset[self.index])
+        self.index += 1
+        return sample
+
+    def __str__(self):
+        return f"HuggingfaceDataset({self.config})"
+
+    def __len__(self):
+        if not self.dataset:
+            self.load()
+        return len(self.dataset)
 
     def load(self):
         if self.dataset:
             return
+        self.dataset = load_dataset(self.config.dataset_name, **self.config.kwargs)
+        # parsed = urlparse(self.url)
 
-        parsed = urlparse(self.url)
+        # if parsed.scheme != "hf":
+        #     raise ValueError(f"Unknown HF dataset: {self.url}")
 
-        if parsed.scheme != "hf":
-            raise ValueError(f"Unknown HF dataset: {self.url}")
-
-        query = parse_qs(parsed.query)
-        query = {k: v[0] for k, v in query.items()}
-        path = parsed.netloc
-        self.dataset = load_dataset(path, **query)
+        # query = parse_qs(parsed.query)
+        # query = {k: v[0] for k, v in query.items()}
+        # path = parsed.netloc
+        # self.dataset = load_dataset(path, **query)
diff --git a/llama_stack/distribution/registry/datasets/dataset_registry.py b/llama_stack/distribution/registry/datasets/dataset_registry.py
index 9ddaa8bb7..8e9b22266 100644
--- a/llama_stack/distribution/registry/datasets/dataset_registry.py
+++ b/llama_stack/distribution/registry/datasets/dataset_registry.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 from typing import AbstractSet, Dict
 
-from .dataset import BaseDataset
+from llama_stack.apis.dataset import BaseDataset
 
 
 class DatasetRegistry:
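
For reviewers, a minimal sketch of how the new DatasetDef discriminated union is meant to round-trip. This is illustration only, not part of the diff; it assumes pydantic v2 (TypeAdapter; on v1, parse_obj_as would be the equivalent) and that the models are re-exported from llama_stack.apis.dataset as the registry-side wildcard import suggests.

    # Hypothetical round-trip through the new discriminated union (illustration only).
    from pydantic import TypeAdapter

    from llama_stack.apis.dataset import (
        CustomDatasetDef,
        DatasetDef,
        HuggingfaceDatasetDef,
    )

    # The literal "type" field drives the discriminator, so a plain dict
    # deserializes into the matching subclass without manual dispatch.
    payload = {
        "type": "huggingface",
        "identifier": "hellaswag",
        "dataset_name": "hellaswag",
        "kwargs": {"split": "validation", "trust_remote_code": True},
    }
    dataset_def = TypeAdapter(DatasetDef).validate_python(payload)
    assert isinstance(dataset_def, HuggingfaceDatasetDef)

    # The custom variant carries a URL instead of an HF dataset name.
    csv_def = CustomDatasetDef(
        identifier="mmlu-simple-eval-en",
        url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
    )
    assert csv_def.type == "custom"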
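
And a sketch of the iterator contract the registry-side classes implement: __next__ lazy-loads the underlying dataset, so callers can iterate directly. The hellaswag column name "ctx" is an assumption about that dataset's schema, not something this patch guarantees.

    # Hypothetical driver for the BaseDataset iterator contract (illustration only).
    from llama_stack.apis.dataset import HuggingfaceDatasetDef
    from llama_stack.distribution.registry.datasets.dataset import HuggingfaceDataset

    dataset = HuggingfaceDataset(
        HuggingfaceDatasetDef(
            identifier="hellaswag",
            dataset_name="hellaswag",
            kwargs={"split": "validation"},
        )
    )

    # __next__ calls load() on first use, so len() and iteration both work
    # without an explicit load(); each row arrives wrapped as a DictSample.
    print(len(dataset))
    for i, sample in enumerate(dataset):
        print(sample.data["ctx"])  # "ctx" is a hellaswag-specific column (assumed)
        if i >= 2:
            break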