mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)

Commit ad18dc94ac (parent 9816c9aae6): add data structure to tasks
7 changed files with 100 additions and 168 deletions

@@ -9,11 +9,12 @@ from enum import Enum
 from typing import Any, Dict, Generic, Iterator, Literal, Protocol, TypeVar, Union

 from llama_models.schema_utils import json_schema_type, webmethod
 from llama_models.llama3.api.datatypes import * # noqa: F403

 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

-# A sample (row) from raw dataset
+# A sample (row) from dataset
 TDatasetSample = TypeVar("TDatasetSample")

@@ -26,46 +27,20 @@ class DictSample(DatasetSample):
     data: Dict[str, Any]


+# A sample (row) from evals intermediate dataset
+TProcessedSample = TypeVar("TProcessedSample")
+
+
 @json_schema_type
-class ProcessedDictSample(DatasetSample):
-    data: Dict[str, Any]
-    preprocessed: Dict[str, Any]
-    prediction: Dict[str, Any]
-    postprocessed: Dict[str, Any]
+class PredictionSample(BaseModel):
+    completion_message: str


-# # A sample (row) after preprocessing the raw dataset
-# TPreprocessedSample = TypeVar("TPreprocessedSample")
-
-# @json_schema_type
-# class PreprocessedSample(BaseModel): ...
-
-# @json_schema_type
-# class InferencePreprocessedSample(PreprocessedSample):
-#     # TODO: either keep it generic or specific to inference API
-#     # messages: List[Message]
-#     data: Dict[str, Any]
-
-# # A sample (row) from model prediction output
-# TPredictionSample = TypeVar("TPredictionSample")
-
-# @json_schema_type
-# class PredictionSample(BaseModel): ...
-
-# @json_schema_type
-# class InferencePredictionSample(PredictionSample):
-#     data: Dict[str, Any]
-
-
-# # A sample (row) from post-processed output
-# TPostprocessedSample = TypeVar("TPostprocessedSample")
-
-# @json_schema_type
-# class PostprocessedSample(BaseModel): ...
-
-# @json_schema_type
-# class InferencePostprocessedSample(PredictionSample):
-#     data: Dict[str, Any]
+@json_schema_type
+class ProcessedDictSample(DictSample):
+    preprocessed: Optional[Dict[str, Any]] = None
+    prediction: Optional[PredictionSample] = None
+    postprocessed: Optional[Dict[str, Any]] = None


 @json_schema_type
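
The three models above replace the old all-required ProcessedDictSample: a row now starts as a DictSample and accumulates optional preprocessed/prediction/postprocessed fields as it moves through a task. A minimal sketch of that lifecycle, using simplified stand-ins for the classes in this hunk (plain pydantic models, no json_schema_type decorator or DatasetSample base) and made-up row contents:

from typing import Any, Dict, Optional

from pydantic import BaseModel


# Simplified stand-ins for the models introduced in this hunk.
class DictSample(BaseModel):
    data: Dict[str, Any]


class PredictionSample(BaseModel):
    completion_message: str


class ProcessedDictSample(DictSample):
    preprocessed: Optional[Dict[str, Any]] = None
    prediction: Optional[PredictionSample] = None
    postprocessed: Optional[Dict[str, Any]] = None


# One row accumulates state stage by stage instead of requiring every field up front.
row = ProcessedDictSample(data={"Question": "2 + 2 = ?", "Answer": "4"})
row.preprocessed = {"messages": [{"role": "user", "content": "2 + 2 = ?"}]}
row.prediction = PredictionSample(completion_message="Answer: 4")
row.postprocessed = {"postprocessed": "answer: 4"}
print(row.prediction.completion_message)  # -> Answer: 4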

@@ -71,16 +71,7 @@ class EvaluateTaskConfig(BaseModel):
     sampling_params: SamplingParams = SamplingParams()


-class BaseTask(
-    ABC,
-    Generic[
-        TDatasetSample,
-        TPreprocessedSample,
-        TPredictionSample,
-        TPostprocessedSample,
-        TSingleEvalResult,
-    ],
-):
+class BaseTask(ABC, Generic[TDatasetSample, TProcessedSample]):
     """
     A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
     Base class for all evaluation tasks. Each task needs to implement the following methods:

@@ -94,17 +85,15 @@ class BaseTask(
         self._name = self.__class__.__name__

     @abstractmethod
-    def preprocess_sample(self, sample: TDatasetSample) -> TPreprocessedSample:
+    def preprocess_sample(self, sample: TDatasetSample) -> TProcessedSample:
         raise NotImplementedError()

     @abstractmethod
-    def postprocess_sample(self, sample: TPredictionSample) -> TPostprocessedSample:
+    def postprocess_sample(self, sample: TProcessedSample) -> TProcessedSample:
         raise NotImplementedError()

     @abstractmethod
-    def score_sample(
-        self, sample: TPostprocessedSample, ground_truth: TPreprocessedSample
-    ):
+    def score_sample(self, sample: TProcessedSample) -> SingleEvalResult:
         raise NotImplementedError()

     @abstractmethod

@@ -112,24 +101,15 @@ class BaseTask(
         raise NotImplementedError()

     def preprocess(
-        self, dataset: BaseDataset[TDatasetSample]
-    ) -> List[TPreprocessedSample]:
-        return [self.preprocess_sample(sample) for sample in self.dataset]
+        self, dataset: BaseDataset[TProcessedSample]
+    ) -> List[TProcessedSample]:
+        return [self.preprocess_sample(sample) for sample in dataset]

-    def postprocess(
-        self, generation: List[TPredictionSample]
-    ) -> List[TPostprocessedSample]:
+    def postprocess(self, generation: List[TProcessedSample]) -> List[TProcessedSample]:
         return [self.postprocess_sample(sample) for sample in generation]

-    def score(
-        self,
-        postprocessed: List[TPostprocessedSample],
-        preprocessed_dataset: List[TPreprocessedSample],
-    ) -> List[TSingleEvalResult]:
-        return [
-            self.score_sample(sample, ground_truth)
-            for sample, ground_truth in zip(postprocessed, self.preprocessed_dataset)
-        ]
+    def score(self, postprocessed: List[TProcessedSample]) -> List[SingleEvalResult]:
+        return [self.score_sample(sample) for sample in postprocessed]


 class Evals(Protocol):
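
Taken together, the three hunks above collapse BaseTask from five type parameters to two (TDatasetSample in, TProcessedSample out), and score_sample now reads the ground truth off the sample itself instead of zipping against a separate preprocessed dataset. As a point of reference, a hypothetical subclass (not part of this commit) might look like the sketch below; it assumes the llama-stack source at this commit is importable and that BaseTask, DictSample, ProcessedDictSample, PredictionSample, SingleEvalResult, and EvalResult are all exposed by llama_stack.apis.evals, as the other hunks in this diff suggest.

# Hypothetical exact-match task against the reshaped two-parameter BaseTask.
from typing import List

from llama_stack.apis.evals import *  # noqa: F403


class ExactMatchTask(BaseTask[DictSample, ProcessedDictSample]):
    """Scores 1.0 when the stripped completion equals the row's 'Answer' field."""

    def preprocess_sample(self, sample: DictSample) -> ProcessedDictSample:
        # Wrap the raw row in chat messages, mirroring the shape used by MMLUTask.
        return ProcessedDictSample(
            data=sample.data,
            preprocessed={
                "messages": [{"role": "user", "content": sample.data["Question"]}]
            },
        )

    def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
        sample.postprocessed = {
            "postprocessed": sample.prediction.completion_message.strip()
        }
        return sample

    def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
        hit = sample.postprocessed["postprocessed"] == sample.data["Answer"]
        return SingleEvalResult(score_data={"score": 1.0 if hit else 0.0})

    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
        scores = [r.score_data["score"] for r in eval_results]
        return EvalResult(metrics={"score": str(sum(scores) / max(len(scores), 1))})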

@@ -1,49 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-from abc import ABC, abstractmethod
-
-
-class BaseTask(ABC):
-    """
-    A task represents a single evaluation benchmark, including it's dataset, preprocessing, postprocessing and scoring methods.
-    Base class for all evaluation tasks. Each task needs to implement the following methods:
-    - F1: preprocess_sample(self)
-    - F2: postprocess_sample(self)
-    - F3: score_sample(self)
-    """
-
-    def __init__(self, dataset, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._name = self.__class__.__name__
-        self.dataset = dataset
-
-    @abstractmethod
-    def preprocess_sample(self, sample):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def postprocess_sample(self, sample):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def score_sample(self, sample, ground_truth):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def aggregate_results(self, eval_results):
-        raise NotImplementedError()
-
-    def preprocess(self):
-        return [self.preprocess_sample(sample) for sample in self.dataset]
-
-    def postprocess(self, generation):
-        return [self.postprocess_sample(sample) for sample in generation]
-
-    def score(self, postprocessed):
-        return [
-            self.score_sample(sample, ground_truth)
-            for sample, ground_truth in zip(postprocessed, self.dataset)
-        ]

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 from typing import AbstractSet, Dict

-from .task import BaseTask
+from llama_stack.apis.evals import BaseTask


 class TaskRegistry:

@@ -6,6 +6,8 @@
 from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.evals import * # noqa: F403
+from llama_stack.apis.dataset import * # noqa: F403

 from termcolor import cprint

 from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry

@@ -42,30 +44,37 @@ class MetaReferenceEvalsImpl(Evals):
         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()

-        task_impl = TaskRegistry.get_task(task)(dataset)
-        x1 = task_impl.preprocess()
+        task_impl = TaskRegistry.get_task(task)()
+        preprocessed = task_impl.preprocess(dataset)

        # TODO: replace w/ batch inference & async return eval job
         generation_outputs = []
         if eval_task_config is None:
-            eval_task_config = EvaluateTaskConfig(n_samples=len(x1))
-        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(x1):
-            eval_task_config.n_samples = len(x1)
+            eval_task_config = EvaluateTaskConfig(n_samples=len(preprocessed))
+        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(
+            preprocessed
+        ):
+            eval_task_config.n_samples = len(preprocessed)

         print(
             f"Eval generation start, generate on {eval_task_config.n_samples} samples"
         )

-        for msg in x1[: eval_task_config.n_samples]:
-            print("generation for msg: ", msg)
+        for sample in preprocessed[: eval_task_config.n_samples]:
+            print("generation: ", sample)
             response = await self.inference_api.chat_completion(
                 model=model,
-                messages=[msg],
+                messages=sample.preprocessed["messages"],
                 stream=False,
             )
-            generation_outputs.append(response.completion_message.content)
+            sample.prediction = PredictionSample(
+                completion_message=response.completion_message.content
+            )
+            generation_outputs.append(sample)

-        x2 = task_impl.postprocess(generation_outputs)
-        eval_results = task_impl.score(x2)
-        eval_response = task_impl.aggregate_results(eval_results)
-        return eval_response
+        postprocessed = task_impl.postprocess(generation_outputs)
+        eval_results = task_impl.score(postprocessed)
+        aggr_result = task_impl.aggregate_results(eval_results)
+        return EvaluateResponse(
+            eval_result=aggr_result,
+        )
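
With tasks no longer holding the dataset, the provider loop now reads: preprocess the dataset, run chat_completion per sample, attach the completion to the sample as a PredictionSample, then postprocess, score, and aggregate. A hypothetical dry run of those stages (not part of this commit) with the inference call replaced by a canned completion, reusing the ExactMatchTask sketch from earlier and the same import assumptions:

# Hypothetical dry run of the new pipeline shape; the model call is stubbed out.
from llama_stack.apis.evals import DictSample, PredictionSample

task = ExactMatchTask()

rows = [
    DictSample(data={"Question": "Capital of France?", "Answer": "Paris"}),
    DictSample(data={"Question": "2 + 2 = ?", "Answer": "4"}),
]

# preprocess -> (stubbed) generation -> postprocess -> score -> aggregate
preprocessed = [task.preprocess_sample(row) for row in rows]
for sample in preprocessed:
    # Stand-in for inference_api.chat_completion(); echo the expected answer.
    sample.prediction = PredictionSample(completion_message=sample.data["Answer"])

postprocessed = task.postprocess(preprocessed)
eval_results = task.score(postprocessed)
print(task.aggregate_results(eval_results))  # both rows match, so the score is 1.0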

@@ -6,7 +6,8 @@
 import re

-from llama_stack.distribution.registry.tasks.task import BaseTask
+from llama_stack.apis.evals import * # noqa: F403
+
+# from llama_stack.distribution.registry.tasks.task import BaseTask

 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.

@@ -111,48 +112,60 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
     )


-class MMLUTask(
-    BaseTask[
-        DictSample,
-        InferencePreprocessedSample,
-        InferencePredictionSample,
-        InferencePostprocessedSample,
-    ]
-):
+class MMLUTask(BaseTask[DictSample, ProcessedDictSample]):
     """
     MMLU Task.
     """

-    def __init__(self, dataset, *args, **kwargs):
-        super().__init__(dataset, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

-    def preprocess_sample(self, sample):
-        print(sample)
-        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
-        return {
-            "role": "user",
-            "content": content,
-        }
+    def preprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample.data)
+        preprocessed = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ],
+        }
+        processed_sample = ProcessedDictSample(
+            data=sample.data,
+            preprocessed=preprocessed,
+        )
+        return processed_sample

-    def postprocess_sample(self, sample):
-        normalized = normalize_response(sample)
-        return normalized
+    def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        if not sample.postprocessed:
+            sample.postprocessed = {}
+        sample.postprocessed["postprocessed"] = normalize_response(
+            sample.prediction.completion_message
+        )
+        return sample

-    def score_sample(self, sample, expected):
+    def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
+        postprocessed_output = sample.postprocessed["postprocessed"]
+        expected_answer = sample.data["Answer"]
+
         extracted_answer = None
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-            match = re.search(regex, sample)
+            match = re.search(regex, postprocessed_output)
             if match:
                 extracted_answer = normalize_extracted_answer(match.group(1))
                 break
-        score = (
-            1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
-        )
-        # TODO: generalize this into SingleEvalResult
-        return score
-
-    def aggregate_results(self, eval_results):
-        return EvaluateResponse(
-            metrics={"score": str(sum(eval_results) / len(eval_results))}
+        score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
+
+        return SingleEvalResult(
+            score_data={
+                "score": score,
+            },
         )
+
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        print("aggregate_results", eval_results)
+        sum_score = sum([result.score_data["score"] for result in eval_results])
+
+        return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
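
The MMLUTask hunk shows the same sample-centric flow end to end: preprocess_sample returns a ProcessedDictSample wrapping chat messages, postprocess_sample normalizes the stored completion, and score_sample pulls the expected answer from sample.data["Answer"] and returns a SingleEvalResult. A hypothetical single-row walkthrough (not part of this commit) is sketched below; the module path for MMLUTask and the Question/A/B/C/D column names are assumptions, since only the "Answer" column and the first line of QUERY_TEMPLATE_MULTICHOICE are visible in this diff.

# Hypothetical walkthrough of one MMLU-style row through the new MMLUTask.
from llama_stack.apis.evals import DictSample, PredictionSample
from llama_stack.distribution.registry.tasks.mmlu_task import MMLUTask  # assumed path

task = MMLUTask()
row = DictSample(
    data={
        "Question": "Which planet is known as the Red Planet?",  # assumed column names
        "A": "Venus",
        "B": "Mars",
        "C": "Jupiter",
        "D": "Saturn",
        "Answer": "B",
    }
)

sample = task.preprocess_sample(row)      # fills sample.preprocessed["messages"]
sample.prediction = PredictionSample(     # stand-in for the chat_completion call
    completion_message="Answer: B"
)
sample = task.postprocess_sample(sample)  # normalizes the completion text
result = task.score_sample(sample)        # extracts the letter, compares to "Answer"
print(result.score_data["score"])         # 1.0 if the regexes extract "B"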

@@ -18,14 +18,18 @@ providers:
     provider_type: meta-reference
     config: {}
   inference:
-  - provider_id: meta-reference
-    provider_type: meta-reference
+  - provider_id: remote::tgi
+    provider_type: remote::tgi
     config:
-      model: Llama3.1-8B-Instruct
-      quantization: null
-      torch_seed: null
-      max_seq_len: 4096
-      max_batch_size: 1
+      url: http://127.0.0.1:5009
+  # - provider_id: meta-reference
+  #   provider_type: meta-reference
+  #   config:
+  #     model: Llama3.1-8B-Instruct
+  #     quantization: null
+  #     torch_seed: null
+  #     max_seq_len: 4096
+  #     max_batch_size: 1
   safety:
   - provider_id: meta-reference
     provider_type: meta-reference
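
This run config swaps the in-process meta-reference inference provider for a remote::tgi provider pointed at http://127.0.0.1:5009, keeping the old meta-reference block around as comments. A quick, hypothetical pre-flight check that something is listening on that URL before starting the stack (assumes the standard text-generation-inference /health route; verify it for your TGI version):

# Hypothetical sanity check for the remote::tgi url configured above.
import urllib.request

TGI_URL = "http://127.0.0.1:5009"  # matches the url field in the inference provider

with urllib.request.urlopen(f"{TGI_URL}/health", timeout=5) as resp:
    # TGI returns 200 once the server is up and the model is loaded.
    print("TGI reachable:", resp.status == 200)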