add data structure to tasks

Xi Yan 2024-10-10 21:33:13 -07:00
parent 9816c9aae6
commit ad18dc94ac
7 changed files with 100 additions and 168 deletions

View file

@@ -9,11 +9,12 @@ from enum import Enum
 from typing import Any, Dict, Generic, Iterator, Literal, Protocol, TypeVar, Union
 from llama_models.schema_utils import json_schema_type, webmethod
+from llama_models.llama3.api.datatypes import * # noqa: F403
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
-# A sample (row) from raw dataset
+# A sample (row) from dataset
 TDatasetSample = TypeVar("TDatasetSample")
@@ -26,46 +27,20 @@ class DictSample(DatasetSample):
     data: Dict[str, Any]
+# A sample (row) from evals intermediate dataset
+TProcessedSample = TypeVar("TProcessedSample")
 @json_schema_type
-class ProcessedDictSample(DatasetSample):
-    data: Dict[str, Any]
-    preprocessed: Dict[str, Any]
-    prediction: Dict[str, Any]
-    postprocessed: Dict[str, Any]
+class PredictionSample(BaseModel):
+    completion_message: str
+@json_schema_type
+class ProcessedDictSample(DictSample):
+    preprocessed: Optional[Dict[str, Any]] = None
+    prediction: Optional[PredictionSample] = None
+    postprocessed: Optional[Dict[str, Any]] = None
-# # A sample (row) after preprocessing the raw dataset
-# TPreprocessedSample = TypeVar("TPreprocessedSample")
-# @json_schema_type
-# class PreprocessedSample(BaseModel): ...
-# @json_schema_type
-# class InferencePreprocessedSample(PreprocessedSample):
-#     # TODO: either keep it generic or specific to inference API
-#     # messages: List[Message]
-#     data: Dict[str, Any]
-# # A sample (row) from model prediction output
-# TPredictionSample = TypeVar("TPredictionSample")
-# @json_schema_type
-# class PredictionSample(BaseModel): ...
-# @json_schema_type
-# class InferencePredictionSample(PredictionSample):
-#     data: Dict[str, Any]
-# # A sample (row) from post-processed output
-# TPostprocessedSample = TypeVar("TPostprocessedSample")
-# @json_schema_type
-# class PostprocessedSample(BaseModel): ...
-# @json_schema_type
-# class InferencePostprocessedSample(PredictionSample):
-#     data: Dict[str, Any]
 @json_schema_type
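Taken together, this file now models one evaluation row as a single object that accumulates state: PredictionSample wraps the model completion, and ProcessedDictSample extends DictSample with optional preprocessed, prediction, and postprocessed fields that are filled in stage by stage. Below is a minimal sketch of how the new models compose; DictSample is stubbed on plain BaseModel here because its DatasetSample base is defined outside this hunk, and the Question/Answer keys are illustrative rather than mandated by the schema.

    from typing import Any, Dict, Optional

    from pydantic import BaseModel


    class DictSample(BaseModel):  # stand-in for the real DictSample(DatasetSample)
        data: Dict[str, Any]


    class PredictionSample(BaseModel):
        completion_message: str


    class ProcessedDictSample(DictSample):
        preprocessed: Optional[Dict[str, Any]] = None
        prediction: Optional[PredictionSample] = None
        postprocessed: Optional[Dict[str, Any]] = None


    # One row accumulating state as it moves through the eval pipeline.
    sample = ProcessedDictSample(data={"Question": "What is 2 + 2?", "Answer": "B"})
    sample.preprocessed = {"messages": [{"role": "user", "content": "What is 2 + 2?"}]}
    sample.prediction = PredictionSample(completion_message="Answer: B")
    sample.postprocessed = {"postprocessed": "Answer: B"}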

View file

@@ -71,16 +71,7 @@ class EvaluateTaskConfig(BaseModel):
     sampling_params: SamplingParams = SamplingParams()
-class BaseTask(
-    ABC,
-    Generic[
-        TDatasetSample,
-        TPreprocessedSample,
-        TPredictionSample,
-        TPostprocessedSample,
-        TSingleEvalResult,
-    ],
-):
+class BaseTask(ABC, Generic[TDatasetSample, TProcessedSample]):
     """
     A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
     Base class for all evaluation tasks. Each task needs to implement the following methods:
@@ -94,17 +85,15 @@ class BaseTask(
         self._name = self.__class__.__name__
     @abstractmethod
-    def preprocess_sample(self, sample: TDatasetSample) -> TPreprocessedSample:
+    def preprocess_sample(self, sample: TDatasetSample) -> TProcessedSample:
         raise NotImplementedError()
     @abstractmethod
-    def postprocess_sample(self, sample: TPredictionSample) -> TPostprocessedSample:
+    def postprocess_sample(self, sample: TProcessedSample) -> TProcessedSample:
         raise NotImplementedError()
     @abstractmethod
-    def score_sample(
-        self, sample: TPostprocessedSample, ground_truth: TPreprocessedSample
-    ):
+    def score_sample(self, sample: TProcessedSample) -> SingleEvalResult:
         raise NotImplementedError()
     @abstractmethod
@@ -112,24 +101,15 @@ class BaseTask(
         raise NotImplementedError()
     def preprocess(
-        self, dataset: BaseDataset[TDatasetSample]
-    ) -> List[TPreprocessedSample]:
-        return [self.preprocess_sample(sample) for sample in self.dataset]
+        self, dataset: BaseDataset[TProcessedSample]
+    ) -> List[TProcessedSample]:
+        return [self.preprocess_sample(sample) for sample in dataset]
-    def postprocess(
-        self, generation: List[TPredictionSample]
-    ) -> List[TPostprocessedSample]:
+    def postprocess(self, generation: List[TProcessedSample]) -> List[TProcessedSample]:
         return [self.postprocess_sample(sample) for sample in generation]
-    def score(
-        self,
-        postprocessed: List[TPostprocessedSample],
-        preprocessed_dataset: List[TPreprocessedSample],
-    ) -> List[TSingleEvalResult]:
-        return [
-            self.score_sample(sample, ground_truth)
-            for sample, ground_truth in zip(postprocessed, self.preprocessed_dataset)
-        ]
+    def score(self, postprocessed: List[TProcessedSample]) -> List[SingleEvalResult]:
+        return [self.score_sample(sample) for sample in postprocessed]
 class Evals(Protocol):
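With the signature narrowed to Generic[TDatasetSample, TProcessedSample], every stage consumes and returns the same processed-sample type, and score_sample no longer takes a separate ground_truth argument because the expected answer travels inside the sample's data field. The sketch below shows what a subclass could look like under the new contract; the ExactMatchTask name and the Question/Answer keys are illustrative, and the star imports follow the conventions visible elsewhere in this commit rather than a confirmed public API.

    from typing import List

    from llama_stack.apis.dataset import *  # noqa: F403
    from llama_stack.apis.evals import *  # noqa: F403


    class ExactMatchTask(BaseTask[DictSample, ProcessedDictSample]):
        """Illustrative task: exact string match between completion and expected answer."""

        def preprocess_sample(self, sample: DictSample) -> ProcessedDictSample:
            # Carry the raw row forward; the ground truth stays in .data.
            return ProcessedDictSample(
                data=sample.data,
                preprocessed={
                    "messages": [{"role": "user", "content": sample.data["Question"]}]
                },
            )

        def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
            sample.postprocessed = {"answer": sample.prediction.completion_message.strip()}
            return sample

        def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
            hit = sample.postprocessed["answer"] == sample.data["Answer"]
            return SingleEvalResult(score_data={"score": 1.0 if hit else 0.0})

        def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
            scores = [r.score_data["score"] for r in eval_results]
            return EvalResult(metrics={"score": str(sum(scores) / len(scores))})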

View file

@@ -1,49 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-from abc import ABC, abstractmethod
-class BaseTask(ABC):
-    """
-    A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
-    Base class for all evaluation tasks. Each task needs to implement the following methods:
-        - F1: preprocess_sample(self)
-        - F2: postprocess_sample(self)
-        - F3: score_sample(self)
-    """
-    def __init__(self, dataset, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._name = self.__class__.__name__
-        self.dataset = dataset
-    @abstractmethod
-    def preprocess_sample(self, sample):
-        raise NotImplementedError()
-    @abstractmethod
-    def postprocess_sample(self, sample):
-        raise NotImplementedError()
-    @abstractmethod
-    def score_sample(self, sample, ground_truth):
-        raise NotImplementedError()
-    @abstractmethod
-    def aggregate_results(self, eval_results):
-        raise NotImplementedError()
-    def preprocess(self):
-        return [self.preprocess_sample(sample) for sample in self.dataset]
-    def postprocess(self, generation):
-        return [self.postprocess_sample(sample) for sample in generation]
-    def score(self, postprocessed):
-        return [
-            self.score_sample(sample, ground_truth)
-            for sample, ground_truth in zip(postprocessed, self.dataset)
-        ]

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 from typing import AbstractSet, Dict
-from .task import BaseTask
+from llama_stack.apis.evals import BaseTask
 class TaskRegistry:

View file

@@ -6,6 +6,8 @@
 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.apis.evals import * # noqa: F403
+from llama_stack.apis.dataset import * # noqa: F403
 from termcolor import cprint
 from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
@@ -42,30 +44,37 @@ class MetaReferenceEvalsImpl(Evals):
         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()
-        task_impl = TaskRegistry.get_task(task)(dataset)
-        x1 = task_impl.preprocess()
+        task_impl = TaskRegistry.get_task(task)()
+        preprocessed = task_impl.preprocess(dataset)
         # TODO: replace w/ batch inference & async return eval job
         generation_outputs = []
         if eval_task_config is None:
-            eval_task_config = EvaluateTaskConfig(n_samples=len(x1))
-        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(x1):
-            eval_task_config.n_samples = len(x1)
+            eval_task_config = EvaluateTaskConfig(n_samples=len(preprocessed))
+        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(
+            preprocessed
+        ):
+            eval_task_config.n_samples = len(preprocessed)
         print(
             f"Eval generation start, generate on {eval_task_config.n_samples} samples"
         )
-        for msg in x1[: eval_task_config.n_samples]:
-            print("generation for msg: ", msg)
+        for sample in preprocessed[: eval_task_config.n_samples]:
+            print("generation: ", sample)
             response = await self.inference_api.chat_completion(
                 model=model,
-                messages=[msg],
+                messages=sample.preprocessed["messages"],
                 stream=False,
             )
-            generation_outputs.append(response.completion_message.content)
+            sample.prediction = PredictionSample(
+                completion_message=response.completion_message.content
+            )
+            generation_outputs.append(sample)
-        x2 = task_impl.postprocess(generation_outputs)
-        eval_results = task_impl.score(x2)
-        eval_response = task_impl.aggregate_results(eval_results)
-        return eval_response
+        postprocessed = task_impl.postprocess(generation_outputs)
+        eval_results = task_impl.score(postprocessed)
+        aggr_result = task_impl.aggregate_results(eval_results)
+        return EvaluateResponse(
+            eval_result=aggr_result,
+        )
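The provider's control flow after this change, condensed: preprocess builds ProcessedDictSamples, each sample's preprocessed["messages"] is sent to chat completion, the reply is attached back onto the same sample as a PredictionSample, and those samples then flow through postprocess, score, and aggregate_results. The sketch below is an illustrative driver, not the provider code itself; chat_completion stands in for self.inference_api.chat_completion and is assumed to return an object with a completion_message.content attribute.

    from llama_stack.apis.dataset import *  # noqa: F403


    async def run_task(task_impl, dataset, n_samples, chat_completion):
        # 1) dataset rows -> processed samples with .preprocessed["messages"] filled in
        preprocessed = task_impl.preprocess(dataset)

        # 2) generation: attach each completion to the sample it came from
        generation_outputs = []
        for sample in preprocessed[:n_samples]:
            response = await chat_completion(
                messages=sample.preprocessed["messages"], stream=False
            )
            sample.prediction = PredictionSample(
                completion_message=response.completion_message.content
            )
            generation_outputs.append(sample)

        # 3) postprocess -> score -> aggregate, all over the same sample objects
        postprocessed = task_impl.postprocess(generation_outputs)
        eval_results = task_impl.score(postprocessed)
        return task_impl.aggregate_results(eval_results)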

View file

@@ -6,7 +6,8 @@
 import re
 from llama_stack.apis.evals import * # noqa: F403
-from llama_stack.distribution.registry.tasks.task import BaseTask
+# from llama_stack.distribution.registry.tasks.task import BaseTask
 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
@@ -111,48 +112,60 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
     )
-class MMLUTask(
-    BaseTask[
-        DictSample,
-        InferencePreprocessedSample,
-        InferencePredictionSample,
-        InferencePostprocessedSample,
-    ]
-):
+class MMLUTask(BaseTask[DictSample, ProcessedDictSample]):
     """
     MMLU Task.
     """
-    def __init__(self, dataset, *args, **kwargs):
-        super().__init__(dataset, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
-    def preprocess_sample(self, sample):
-        print(sample)
-        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
-        return {
-            "role": "user",
-            "content": content,
-        }
+    def preprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample.data)
+        preprocessed = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ],
+        }
+        processed_sample = ProcessedDictSample(
+            data=sample.data,
+            preprocessed=preprocessed,
+        )
+        return processed_sample
-    def postprocess_sample(self, sample):
-        normalized = normalize_response(sample)
-        return normalized
+    def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        if not sample.postprocessed:
+            sample.postprocessed = {}
+        sample.postprocessed["postprocessed"] = normalize_response(
+            sample.prediction.completion_message
+        )
+        return sample
+    def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
+        postprocessed_output = sample.postprocessed["postprocessed"]
+        expected_answer = sample.data["Answer"]
-    def score_sample(self, sample, expected):
         extracted_answer = None
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-            match = re.search(regex, sample)
+            match = re.search(regex, postprocessed_output)
             if match:
                 extracted_answer = normalize_extracted_answer(match.group(1))
                 break
-        score = (
-            1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
-        )
-        # TODO: generalize this into SingleEvalResult
-        return score
+        score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
+        return SingleEvalResult(
+            score_data={
+                "score": score,
+            },
+        )
-    def aggregate_results(self, eval_results):
-        return EvaluateResponse(
-            metrics={"score": str(sum(eval_results) / len(eval_results))}
-        )
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        print("aggregate_results", eval_results)
+        sum_score = sum([result.score_data["score"] for result in eval_results])
+        return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
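For reference, the scoring path in score_sample reduces to regex answer extraction followed by an exact match against data["Answer"]. Below is a small worked example of that path; ANSWER_REGEXES, the pattern template, and the inline upper() call are simplified stand-ins for the MULTILINGUAL_ANSWER_REGEXES, MULTILINGUAL_ANSWER_PATTERN_TEMPLATE, and normalize_extracted_answer defined elsewhere in this file.

    import re

    # Simplified stand-ins for the MULTILINGUAL_* constants used by MMLUTask.
    ANSWER_REGEXES = [r"Answer\s*:"]
    ANSWER_PATTERN_TEMPLATE = r"(?i){}[ \t]*([A-D])"


    def extract_and_score(postprocessed_output: str, expected_answer: str) -> float:
        extracted_answer = None
        for answer_regex in ANSWER_REGEXES:
            regex = ANSWER_PATTERN_TEMPLATE.format(answer_regex)
            match = re.search(regex, postprocessed_output)
            if match:
                # stand-in for normalize_extracted_answer
                extracted_answer = match.group(1).upper()
                break
        return 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0


    # extract_and_score("Reasoning...\nAnswer: B", "B") -> 1.0
    # extract_and_score("I am not sure.", "B")          -> 0.0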

View file

@@ -18,14 +18,18 @@ providers:
       provider_type: meta-reference
       config: {}
   inference:
-    - provider_id: meta-reference
-      provider_type: meta-reference
+    - provider_id: remote::tgi
+      provider_type: remote::tgi
       config:
-        model: Llama3.1-8B-Instruct
-        quantization: null
-        torch_seed: null
-        max_seq_len: 4096
-        max_batch_size: 1
+        url: http://127.0.0.1:5009
+    # - provider_id: meta-reference
+    #   provider_type: meta-reference
+    #   config:
+    #     model: Llama3.1-8B-Instruct
+    #     quantization: null
+    #     torch_seed: null
+    #     max_seq_len: 4096
+    #     max_batch_size: 1
   safety:
     - provider_id: meta-reference
       provider_type: meta-reference