diff --git a/llama_stack/apis/dataset/dataset.py b/llama_stack/apis/dataset/dataset.py
index 164e16be4..9a4f442e5 100644
--- a/llama_stack/apis/dataset/dataset.py
+++ b/llama_stack/apis/dataset/dataset.py
@@ -13,20 +13,59 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
-TDatasetRow = TypeVar("TDatasetRow")
+# A sample (row) from raw dataset
+TDatasetSample = TypeVar("TDatasetSample")
 
 
 @json_schema_type
-class DatasetRow(BaseModel): ...
+class DatasetSample(BaseModel): ...
 
 
 @json_schema_type
-class DictSample(DatasetRow):
+class DictSample(DatasetSample):
     data: Dict[str, Any]
 
 
 @json_schema_type
-class Generation(BaseModel): ...
+class ProcessedDictSample(DatasetSample):
+    data: Dict[str, Any]
+    preprocessed: Dict[str, Any]
+    prediction: Dict[str, Any]
+    postprocessed: Dict[str, Any]
+
+
+# # A sample (row) after preprocessing the raw dataset
+# TPreprocessedSample = TypeVar("TPreprocessedSample")
+
+# @json_schema_type
+# class PreprocessedSample(BaseModel): ...
+
+# @json_schema_type
+# class InferencePreprocessedSample(PreprocessedSample):
+#     # TODO: either keep it generic or specific to inference API
+#     # messages: List[Message]
+#     data: Dict[str, Any]
+
+# # A sample (row) from model prediction output
+# TPredictionSample = TypeVar("TPredictionSample")
+
+# @json_schema_type
+# class PredictionSample(BaseModel): ...
+
+# @json_schema_type
+# class InferencePredictionSample(PredictionSample):
+#     data: Dict[str, Any]
+
+
+# # A sample (row) from post-processed output
+# TPostprocessedSample = TypeVar("TPostprocessedSample")
+
+# @json_schema_type
+# class PostprocessedSample(BaseModel): ...
+
+# @json_schema_type
+# class InferencePostprocessedSample(PostprocessedSample):
+#     data: Dict[str, Any]
 
 
 @json_schema_type
@@ -70,16 +109,17 @@ DatasetDef = Annotated[
 ]
 
 
-class BaseDataset(ABC, Generic[TDatasetRow]):
+class BaseDataset(ABC, Generic[TDatasetSample]):
     def __init__(self) -> None:
         self.type: str = self.__class__.__name__
 
+    @property
     @abstractmethod
-    def __iter__(self) -> Iterator[TDatasetRow]:
+    def dataset_id(self) -> str:
         raise NotImplementedError()
 
     @abstractmethod
-    def load(self) -> None:
+    def __iter__(self) -> Iterator[TDatasetSample]:
         raise NotImplementedError()
 
     @abstractmethod
@@ -90,6 +130,10 @@ class BaseDataset(ABC, Generic[TDatasetRow]):
     def __len__(self) -> int:
         raise NotImplementedError()
 
+    @abstractmethod
+    def load(self) -> None:
+        raise NotImplementedError()
+
 
 class Datasets(Protocol):
     @webmethod(route="/datasets/create")
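Usage sketch (illustrative, not part of this patch): a minimal in-memory dataset written against the reordered BaseDataset interface above. The ListDataset class and its rows are hypothetical; only the abstract methods (dataset_id, __iter__, __str__, __len__, load), DictSample, and the llama_stack.apis.dataset import path come from this diff.

    from typing import Any, Dict, Iterator, List

    from llama_stack.apis.dataset import BaseDataset, DictSample


    class ListDataset(BaseDataset[DictSample]):
        """Hypothetical in-memory dataset, used only to illustrate the interface."""

        def __init__(self, identifier: str, rows: List[Dict[str, Any]]) -> None:
            super().__init__()
            self._identifier = identifier
            self._rows = rows
            self.dataset = None

        @property
        def dataset_id(self) -> str:
            return self._identifier

        def __iter__(self) -> Iterator[DictSample]:
            # Lazy load, then yield one DictSample per raw row.
            if not self.dataset:
                self.load()
            return (DictSample(data=x) for x in self.dataset)

        def __str__(self) -> str:
            return f"ListDataset({self._identifier})"

        def __len__(self) -> int:
            if not self.dataset:
                self.load()
            return len(self.dataset)

        def load(self) -> None:
            # Loading is trivial here; real implementations fetch a CSV or an HF dataset.
            self.dataset = self._rows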
diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py
index 629e68d32..53a2ff6df 100644
--- a/llama_stack/apis/evals/evals.py
+++ b/llama_stack/apis/evals/evals.py
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Protocol
+from abc import ABC, abstractmethod
+from typing import Dict, Generic, List, Protocol
 
 from llama_models.schema_utils import webmethod
-
 from pydantic import BaseModel
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -22,19 +22,26 @@ class EvaluationJobLogStream(BaseModel):
     job_uuid: str
 
 
-class EvaluateTaskConfig(BaseModel):
-    # num examples to evaluate, evaluate all if None
-    n_samples: Optional[int] = None
-    # model evaluation params
-    sampling_params: SamplingParams = SamplingParams()
+@json_schema_type
+class EvalResult(BaseModel):
+    """Evaluation result."""
+
+    metrics: Dict[str, str]
+
+
+@json_schema_type
+class SingleEvalResult(BaseModel):
+    """Single evaluation result."""
+
+    score_data: Dict[str, float]
 
 
 @json_schema_type
 class EvaluateResponse(BaseModel):
     """Scores for evaluation."""
 
-    preprocess_output: GenerationOutput
-    metrics: Dict[str, str]
+    eval_result: EvalResult
+    formatted_report: Optional[str] = None
 
 
 @json_schema_type
@@ -56,6 +63,75 @@ class EvaluationJobCreateResponse(BaseModel):
     job_uuid: str
 
 
+@json_schema_type
+class EvaluateTaskConfig(BaseModel):
+    # num examples to evaluate, evaluate all if None
+    n_samples: Optional[int] = None
+    # model evaluation params
+    sampling_params: SamplingParams = SamplingParams()
+
+
+class BaseTask(
+    ABC,
+    Generic[
+        TDatasetSample,
+        TPreprocessedSample,
+        TPredictionSample,
+        TPostprocessedSample,
+        TSingleEvalResult,
+    ],
+):
+    """
+    A task represents a single evaluation benchmark, including its dataset and its preprocessing, postprocessing, and scoring methods.
+    Base class for all evaluation tasks. Each task needs to implement the following methods:
+    - F1: preprocess_sample(self)
+    - F2: postprocess_sample(self)
+    - F3: score_sample(self)
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._name = self.__class__.__name__
+
+    @abstractmethod
+    def preprocess_sample(self, sample: TDatasetSample) -> TPreprocessedSample:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def postprocess_sample(self, sample: TPredictionSample) -> TPostprocessedSample:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def score_sample(
+        self, sample: TPostprocessedSample, ground_truth: TPreprocessedSample
+    ):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        raise NotImplementedError()
+
+    def preprocess(
+        self, dataset: BaseDataset[TDatasetSample]
+    ) -> List[TPreprocessedSample]:
+        return [self.preprocess_sample(sample) for sample in dataset]
+
+    def postprocess(
+        self, generation: List[TPredictionSample]
+    ) -> List[TPostprocessedSample]:
+        return [self.postprocess_sample(sample) for sample in generation]
+
+    def score(
+        self,
+        postprocessed: List[TPostprocessedSample],
+        preprocessed_dataset: List[TPreprocessedSample],
+    ) -> List[TSingleEvalResult]:
+        return [
+            self.score_sample(sample, ground_truth)
+            for sample, ground_truth in zip(postprocessed, preprocessed_dataset)
+        ]
+
+
 class Evals(Protocol):
     @webmethod(route="/evals/run")
     async def run_evals(
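Usage sketch (illustrative, not part of this patch): how the BaseTask hooks are intended to compose into one evaluation pass. The run_task driver and the generate_fn stub are hypothetical, and the import paths assume the api packages re-export these symbols; preprocess, postprocess, score, aggregate_results, and EvalResult come from the class above.

    from typing import Callable, List

    from llama_stack.apis.dataset import BaseDataset
    from llama_stack.apis.evals import BaseTask, EvalResult


    def run_task(
        task: BaseTask,
        dataset: BaseDataset,
        generate_fn: Callable[[List], List],
    ) -> EvalResult:
        # F1: raw dataset samples -> model-ready samples
        preprocessed = task.preprocess(dataset)
        # Model inference is out of scope here; generate_fn stands in for the inference API.
        predictions = generate_fn(preprocessed)
        # F2: raw generations -> normalized predictions
        postprocessed = task.postprocess(predictions)
        # F3: per-sample scores, then aggregation into a single EvalResult
        single_results = task.score(postprocessed, preprocessed)
        return task.aggregate_results(single_results)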
diff --git a/llama_stack/distribution/registry/datasets/__init__.py b/llama_stack/distribution/registry/datasets/__init__.py
index 3a60d6a5e..f0636212a 100644
--- a/llama_stack/distribution/registry/datasets/__init__.py
+++ b/llama_stack/distribution/registry/datasets/__init__.py
@@ -5,19 +5,25 @@
 # the root directory of this source tree.
 
 # TODO: make these import config based
-# from .dataset import CustomDataset, HFDataset
-# from .dataset_registry import DatasetRegistry
+from llama_stack.apis.dataset import *  # noqa: F403
+from .dataset import CustomDataset, HuggingfaceDataset
+from .dataset_registry import DatasetRegistry
 
-# DATASETS_REGISTRY = {
-#     "mmlu-simple-eval-en": CustomDataset(
-#         name="mmlu_eval",
-#         url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-#     ),
-#     "hellaswag": HFDataset(
-#         name="hellaswag",
-#         url="hf://hellaswag?split=validation&trust_remote_code=True",
-#     ),
-# }
+DATASETS_REGISTRY = [
+    CustomDataset(
+        config=CustomDatasetDef(
+            identifier="mmlu-simple-eval-en",
+            url="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+        )
+    ),
+    HuggingfaceDataset(
+        config=HuggingfaceDatasetDef(
+            identifier="hellaswag",
+            dataset_name="hellaswag",
+            kwargs={"split": "validation", "trust_remote_code": True},
+        )
+    ),
+]
 
-# for k, v in DATASETS_REGISTRY.items():
-#     DatasetRegistry.register(k, v)
+for d in DATASETS_REGISTRY:
+    DatasetRegistry.register(d.dataset_id, d)
diff --git a/llama_stack/distribution/registry/datasets/dataset.py b/llama_stack/distribution/registry/datasets/dataset.py
index e3a2de399..87a01d311 100644
--- a/llama_stack/distribution/registry/datasets/dataset.py
+++ b/llama_stack/distribution/registry/datasets/dataset.py
@@ -16,17 +16,14 @@ class CustomDataset(BaseDataset[DictSample]):
         self.dataset = None
         self.index = 0
 
-    def __iter__(self) -> Iterator[DictSample]:
-        return self
+    @property
+    def dataset_id(self) -> str:
+        return self.config.identifier
 
-    def __next__(self) -> DictSample:
+    def __iter__(self) -> Iterator[DictSample]:
         if not self.dataset:
             self.load()
-        if self.index >= len(self.dataset):
-            raise StopIteration
-        sample = DictSample(data=self.dataset[self.index])
-        self.index += 1
-        return sample
+        return (DictSample(data=x) for x in self.dataset)
 
     def __str__(self):
         return f"CustomDataset({self.config})"
@@ -53,19 +50,15 @@ class HuggingfaceDataset(BaseDataset[DictSample]):
         super().__init__()
         self.config = config
         self.dataset = None
-        self.index = 0
+
+    @property
+    def dataset_id(self) -> str:
+        return self.config.identifier
 
     def __iter__(self) -> Iterator[DictSample]:
-        return self
-
-    def __next__(self) -> DictSample:
         if not self.dataset:
             self.load()
-        if self.index >= len(self.dataset):
-            raise StopIteration
-        sample = DictSample(data=self.dataset[self.index])
-        self.index += 1
-        return sample
+        return (DictSample(data=x) for x in self.dataset)
 
     def __str__(self):
         return f"HuggingfaceDataset({self.config})"
@@ -79,12 +72,3 @@ class HuggingfaceDataset(BaseDataset[DictSample]):
         if self.dataset:
             return
         self.dataset = load_dataset(self.config.dataset_name, **self.config.kwargs)
-        # parsed = urlparse(self.url)
-
-        # if parsed.scheme != "hf":
-        #     raise ValueError(f"Unknown HF dataset: {self.url}")
-
-        # query = parse_qs(parsed.query)
-        # query = {k: v[0] for k, v in query.items()}
-        # path = parsed.netloc
-        # self.dataset = load_dataset(path, **query)
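Usage sketch (illustrative, not part of this patch): reading a registered dataset back out of the registry and iterating it. DatasetRegistry.get() is an assumption, since only register() appears in this diff, and the row fields depend on the underlying CSV/HF dataset.

    from llama_stack.distribution.registry.datasets import DatasetRegistry

    # Hypothetical lookup; the registry's read API is not shown in this patch.
    dataset = DatasetRegistry.get("mmlu-simple-eval-en")
    print(dataset.dataset_id)  # "mmlu-simple-eval-en"

    # Iteration lazily calls load() and yields DictSample objects wrapping raw rows.
    for i, sample in enumerate(dataset):
        print(sample.data)
        if i >= 2:  # peek at the first few rows only
            break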
diff --git a/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py b/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py
index 673a95379..c5c9d9756 100644
--- a/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py
+++ b/llama_stack/providers/impls/meta_reference/evals/tasks/mmlu_task.py
@@ -111,7 +111,14 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
     )
 
 
-class MMLUTask(BaseTask):
+class MMLUTask(
+    BaseTask[
+        DictSample,
+        InferencePreprocessedSample,
+        InferencePredictionSample,
+        InferencePostprocessedSample,
+    ]
+):
     """
     MMLU Task.
     """
@@ -120,6 +127,7 @@ class MMLUTask(BaseTask):
         super().__init__(dataset, *args, **kwargs)
 
     def preprocess_sample(self, sample):
+        print(sample)
         content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
         return {
             "role": "user",