add data structure to tasks

This commit is contained in:
Xi Yan 2024-10-10 21:33:13 -07:00
parent 9816c9aae6
commit ad18dc94ac
7 changed files with 100 additions and 168 deletions

View file

@ -9,11 +9,12 @@ from enum import Enum
from typing import Any, Dict, Generic, Iterator, Literal, Protocol, TypeVar, Union
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.llama3.api.datatypes import * # noqa: F403
from pydantic import BaseModel, Field
from typing_extensions import Annotated
# A sample (row) from raw dataset
# A sample (row) from dataset
TDatasetSample = TypeVar("TDatasetSample")
@ -26,46 +27,20 @@ class DictSample(DatasetSample):
data: Dict[str, Any]
# A sample (row) from evals intermediate dataset
TProcessedSample = TypeVar("TProcessedSample")
@json_schema_type
class ProcessedDictSample(DatasetSample):
data: Dict[str, Any]
preprocessed: Dict[str, Any]
prediction: Dict[str, Any]
postprocessed: Dict[str, Any]
class PredictionSample(BaseModel):
completion_message: str
# # A sample (row) after preprocessing the raw dataset
# TPreprocessedSample = TypeVar("TPreprocessedSample")
# @json_schema_type
# class PreprocessedSample(BaseModel): ...
# @json_schema_type
# class InferencePreprocessedSample(PreprocessedSample):
# # TODO: either keep it generic or specific to inference API
# # messages: List[Message]
# data: Dict[str, Any]
# # A sample (row) from model prediction output
# TPredictionSample = TypeVar("TPredictionSample")
# @json_schema_type
# class PredictionSample(BaseModel): ...
# @json_schema_type
# class InferencePredictionSample(PredictionSample):
# data: Dict[str, Any]
# # A sample (row) from post-processed output
# TPostprocessedSample = TypeVar("TPostprocessedSample")
# @json_schema_type
# class PostprocessedSample(BaseModel): ...
# @json_schema_type
# class InferencePostprocessedSample(PredictionSample):
# data: Dict[str, Any]
@json_schema_type
class ProcessedDictSample(DictSample):
preprocessed: Optional[Dict[str, Any]] = None
prediction: Optional[PredictionSample] = None
postprocessed: Optional[Dict[str, Any]] = None
@json_schema_type

View file

@ -71,16 +71,7 @@ class EvaluateTaskConfig(BaseModel):
sampling_params: SamplingParams = SamplingParams()
class BaseTask(
ABC,
Generic[
TDatasetSample,
TPreprocessedSample,
TPredictionSample,
TPostprocessedSample,
TSingleEvalResult,
],
):
class BaseTask(ABC, Generic[TDatasetSample, TProcessedSample]):
"""
A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
Base class for all evaluation tasks. Each task needs to implement the following methods:
@ -94,17 +85,15 @@ class BaseTask(
self._name = self.__class__.__name__
@abstractmethod
def preprocess_sample(self, sample: TDatasetSample) -> TPreprocessedSample:
def preprocess_sample(self, sample: TDatasetSample) -> TProcessedSample:
raise NotImplementedError()
@abstractmethod
def postprocess_sample(self, sample: TPredictionSample) -> TPostprocessedSample:
def postprocess_sample(self, sample: TProcessedSample) -> TProcessedSample:
raise NotImplementedError()
@abstractmethod
def score_sample(
self, sample: TPostprocessedSample, ground_truth: TPreprocessedSample
):
def score_sample(self, sample: TProcessedSample) -> SingleEvalResult:
raise NotImplementedError()
@abstractmethod
@ -112,24 +101,15 @@ class BaseTask(
raise NotImplementedError()
def preprocess(
self, dataset: BaseDataset[TDatasetSample]
) -> List[TPreprocessedSample]:
return [self.preprocess_sample(sample) for sample in self.dataset]
self, dataset: BaseDataset[TProcessedSample]
) -> List[TProcessedSample]:
return [self.preprocess_sample(sample) for sample in dataset]
def postprocess(
self, generation: List[TPredictionSample]
) -> List[TPostprocessedSample]:
def postprocess(self, generation: List[TProcessedSample]) -> List[TProcessedSample]:
return [self.postprocess_sample(sample) for sample in generation]
def score(
self,
postprocessed: List[TPostprocessedSample],
preprocessed_dataset: List[TPreprocessedSample],
) -> List[TSingleEvalResult]:
return [
self.score_sample(sample, ground_truth)
for sample, ground_truth in zip(postprocessed, self.preprocessed_dataset)
]
def score(self, postprocessed: List[TProcessedSample]) -> List[SingleEvalResult]:
return [self.score_sample(sample) for sample in postprocessed]
class Evals(Protocol):

View file

@ -1,49 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
class BaseTask(ABC):
"""
A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
Base class for all evaluation tasks. Each task needs to implement the following methods:
- F1: preprocess_sample(self)
- F2: postprocess_sample(self)
- F3: score_sample(self)
"""
def __init__(self, dataset, *args, **kwargs):
super().__init__(*args, **kwargs)
self._name = self.__class__.__name__
self.dataset = dataset
@abstractmethod
def preprocess_sample(self, sample):
raise NotImplementedError()
@abstractmethod
def postprocess_sample(self, sample):
raise NotImplementedError()
@abstractmethod
def score_sample(self, sample, ground_truth):
raise NotImplementedError()
@abstractmethod
def aggregate_results(self, eval_results):
raise NotImplementedError()
def preprocess(self):
return [self.preprocess_sample(sample) for sample in self.dataset]
def postprocess(self, generation):
return [self.postprocess_sample(sample) for sample in generation]
def score(self, postprocessed):
return [
self.score_sample(sample, ground_truth)
for sample, ground_truth in zip(postprocessed, self.dataset)
]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import AbstractSet, Dict
from .task import BaseTask
from llama_stack.apis.evals import BaseTask
class TaskRegistry:

View file

@ -6,6 +6,8 @@
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
@ -42,30 +44,37 @@ class MetaReferenceEvalsImpl(Evals):
dataset = DatasetRegistry.get_dataset(dataset)
dataset.load()
task_impl = TaskRegistry.get_task(task)(dataset)
x1 = task_impl.preprocess()
task_impl = TaskRegistry.get_task(task)()
preprocessed = task_impl.preprocess(dataset)
# TODO: replace w/ batch inference & async return eval job
generation_outputs = []
if eval_task_config is None:
eval_task_config = EvaluateTaskConfig(n_samples=len(x1))
if eval_task_config.n_samples is None or eval_task_config.n_samples > len(x1):
eval_task_config.n_samples = len(x1)
eval_task_config = EvaluateTaskConfig(n_samples=len(preprocessed))
if eval_task_config.n_samples is None or eval_task_config.n_samples > len(
preprocessed
):
eval_task_config.n_samples = len(preprocessed)
print(
f"Eval generation start, generate on {eval_task_config.n_samples} samples"
)
for msg in x1[: eval_task_config.n_samples]:
print("generation for msg: ", msg)
for sample in preprocessed[: eval_task_config.n_samples]:
print("generation: ", sample)
response = await self.inference_api.chat_completion(
model=model,
messages=[msg],
messages=sample.preprocessed["messages"],
stream=False,
)
generation_outputs.append(response.completion_message.content)
sample.prediction = PredictionSample(
completion_message=response.completion_message.content
)
generation_outputs.append(sample)
x2 = task_impl.postprocess(generation_outputs)
eval_results = task_impl.score(x2)
eval_response = task_impl.aggregate_results(eval_results)
return eval_response
postprocessed = task_impl.postprocess(generation_outputs)
eval_results = task_impl.score(postprocessed)
aggr_result = task_impl.aggregate_results(eval_results)
return EvaluateResponse(
eval_result=aggr_result,
)

View file

@ -6,7 +6,8 @@
import re
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.distribution.registry.tasks.task import BaseTask
# from llama_stack.distribution.registry.tasks.task import BaseTask
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
@ -111,48 +112,60 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
)
class MMLUTask(
BaseTask[
DictSample,
InferencePreprocessedSample,
InferencePredictionSample,
InferencePostprocessedSample,
]
):
class MMLUTask(BaseTask[DictSample, ProcessedDictSample]):
"""
MMLU Task.
"""
def __init__(self, dataset, *args, **kwargs):
super().__init__(dataset, *args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def preprocess_sample(self, sample):
print(sample)
content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
return {
def preprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
content = QUERY_TEMPLATE_MULTICHOICE.format(**sample.data)
preprocessed = {
"messages": [
{
"role": "user",
"content": content,
}
],
}
processed_sample = ProcessedDictSample(
data=sample.data,
preprocessed=preprocessed,
)
return processed_sample
def postprocess_sample(self, sample):
normalized = normalize_response(sample)
return normalized
def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
if not sample.postprocessed:
sample.postprocessed = {}
sample.postprocessed["postprocessed"] = normalize_response(
sample.prediction.completion_message
)
return sample
def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
postprocessed_output = sample.postprocessed["postprocessed"]
expected_answer = sample.data["Answer"]
def score_sample(self, sample, expected):
extracted_answer = None
for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
match = re.search(regex, sample)
match = re.search(regex, postprocessed_output)
if match:
extracted_answer = normalize_extracted_answer(match.group(1))
break
score = (
1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
)
# TODO: generalize this into SingleEvalResult
return score
def aggregate_results(self, eval_results):
return EvaluateResponse(
metrics={"score": str(sum(eval_results) / len(eval_results))}
score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
return SingleEvalResult(
score_data={
"score": score,
},
)
def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
print("aggregate_results", eval_results)
sum_score = sum([result.score_data["score"] for result in eval_results])
return EvalResult(metrics={"score": str(sum_score / len(eval_results))})

View file

@ -18,14 +18,18 @@ providers:
provider_type: meta-reference
config: {}
inference:
- provider_id: meta-reference
provider_type: meta-reference
- provider_id: remote::tgi
provider_type: remote::tgi
config:
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
url: http://127.0.0.1:5009
# - provider_id: meta-reference
# provider_type: meta-reference
# config:
# model: Llama3.1-8B-Instruct
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
safety:
- provider_id: meta-reference
provider_type: meta-reference