add data structure to tasks

Xi Yan 2024-10-10 21:33:13 -07:00
parent 9816c9aae6
commit ad18dc94ac
7 changed files with 100 additions and 168 deletions


@@ -6,6 +6,8 @@
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.evals import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
@@ -42,30 +44,37 @@ class MetaReferenceEvalsImpl(Evals):
         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()
-        task_impl = TaskRegistry.get_task(task)(dataset)
-        x1 = task_impl.preprocess()
+        task_impl = TaskRegistry.get_task(task)()
+        preprocessed = task_impl.preprocess(dataset)
         # TODO: replace w/ batch inference & async return eval job
         generation_outputs = []
         if eval_task_config is None:
-            eval_task_config = EvaluateTaskConfig(n_samples=len(x1))
-        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(x1):
-            eval_task_config.n_samples = len(x1)
+            eval_task_config = EvaluateTaskConfig(n_samples=len(preprocessed))
+        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(
+            preprocessed
+        ):
+            eval_task_config.n_samples = len(preprocessed)
         print(
             f"Eval generation start, generate on {eval_task_config.n_samples} samples"
         )
-        for msg in x1[: eval_task_config.n_samples]:
-            print("generation for msg: ", msg)
+        for sample in preprocessed[: eval_task_config.n_samples]:
+            print("generation: ", sample)
             response = await self.inference_api.chat_completion(
                 model=model,
-                messages=[msg],
+                messages=sample.preprocessed["messages"],
                 stream=False,
             )
-            generation_outputs.append(response.completion_message.content)
+            sample.prediction = PredictionSample(
+                completion_message=response.completion_message.content
+            )
+            generation_outputs.append(sample)
-        x2 = task_impl.postprocess(generation_outputs)
-        eval_results = task_impl.score(x2)
-        eval_response = task_impl.aggregate_results(eval_results)
-        return eval_response
+        postprocessed = task_impl.postprocess(generation_outputs)
+        eval_results = task_impl.score(postprocessed)
+        aggr_result = task_impl.aggregate_results(eval_results)
+        return EvaluateResponse(
+            eval_result=aggr_result,
+        )
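The loop above now threads one sample object through preprocessing, generation, postprocessing, and scoring instead of passing bare message lists and strings. The definitions of DictSample, ProcessedDictSample, and PredictionSample are not part of the hunks shown here, so the following is only a rough sketch of the shapes this code appears to assume; any field beyond data, preprocessed, prediction, postprocessed, and completion_message is a guess, and the real classes may be pydantic models rather than dataclasses.

# Rough sketch (assumption): the sample containers the loop above relies on.
# The real definitions live elsewhere in the tree and are not shown in this diff.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class PredictionSample:
    # raw completion text returned by the inference API for one sample
    completion_message: str


@dataclass
class DictSample:
    # one raw dataset row (the MMLU scorer below reads an "Answer" key from it)
    data: Dict[str, Any]


@dataclass
class ProcessedDictSample(DictSample):
    # set by preprocess_sample(): chat messages to send to the model
    preprocessed: Optional[Dict[str, Any]] = None
    # set by the generation loop above
    prediction: Optional[PredictionSample] = None
    # set by postprocess_sample(): normalized model output
    postprocessed: Optional[Dict[str, Any]] = None

Under this reading, the provider no longer returns raw metrics from the task; it wraps whatever aggregate_results produces in EvaluateResponse(eval_result=...).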


@@ -6,7 +6,8 @@
 import re
 from llama_stack.apis.evals import * # noqa: F403
-from llama_stack.distribution.registry.tasks.task import BaseTask
+# from llama_stack.distribution.registry.tasks.task import BaseTask
 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
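The template above fixes the output contract: the model's last line should read 'Answer: $LETTER'. In the hunk below, score_sample recovers that letter with MULTILINGUAL_ANSWER_PATTERN_TEMPLATE and MULTILINGUAL_ANSWER_REGEXES, which are defined elsewhere in this file; the snippet here only illustrates the idea with a simplified stand-in pattern, not the actual regexes.

# Illustration only: a simplified stand-in for the multilingual answer regexes
# that score_sample uses below.
import re

ANSWER_PATTERN = r"(?i)answer\s*:\s*([A-D])"

completion = "Paris is the capital of France, so the correct choice is B.\nAnswer: B"
match = re.search(ANSWER_PATTERN, completion)
extracted = match.group(1).upper() if match else None
print(extracted)  # -> B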
@@ -111,48 +112,60 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
     )
-class MMLUTask(
-    BaseTask[
-        DictSample,
-        InferencePreprocessedSample,
-        InferencePredictionSample,
-        InferencePostprocessedSample,
-    ]
-):
+class MMLUTask(BaseTask[DictSample, ProcessedDictSample]):
     """
     MMLU Task.
     """
-    def __init__(self, dataset, *args, **kwargs):
-        super().__init__(dataset, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
-    def preprocess_sample(self, sample):
-        print(sample)
-        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
-        return {
-            "role": "user",
-            "content": content,
-        }
+    def preprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample.data)
+        preprocessed = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ],
+        }
+        processed_sample = ProcessedDictSample(
+            data=sample.data,
+            preprocessed=preprocessed,
+        )
+        return processed_sample
-    def postprocess_sample(self, sample):
-        normalized = normalize_response(sample)
-        return normalized
+    def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        if not sample.postprocessed:
+            sample.postprocessed = {}
+        sample.postprocessed["postprocessed"] = normalize_response(
+            sample.prediction.completion_message
+        )
+        return sample
+    def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
+        postprocessed_output = sample.postprocessed["postprocessed"]
+        expected_answer = sample.data["Answer"]
-    def score_sample(self, sample, expected):
         extracted_answer = None
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-            match = re.search(regex, sample)
+            match = re.search(regex, postprocessed_output)
             if match:
                 extracted_answer = normalize_extracted_answer(match.group(1))
                 break
-        score = (
-            1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
-        )
-        # TODO: generalize this into SingleEvalResult
-        return score
-    def aggregate_results(self, eval_results):
-        return EvaluateResponse(
-            metrics={"score": str(sum(eval_results) / len(eval_results))}
+        score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
+        return SingleEvalResult(
+            score_data={
+                "score": score,
+            },
         )
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        print("aggregate_results", eval_results)
+        sum_score = sum([result.score_data["score"] for result in eval_results])
+        return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
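MMLUTask now only supplies per-sample logic; the batch calls made by the provider in the first file (task_impl.preprocess, postprocess, score, aggregate_results) are presumably provided by BaseTask, whose definition is not in this excerpt (its direct import is even commented out above). Below is a minimal sketch of that implied contract together with the result types as they are used here; the names and signatures are assumptions, not the actual registry code.

# Sketch (assumption): the generic task contract implied by both files in this
# commit. The real BaseTask lives in the tasks registry module and may differ.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Generic, Iterable, List, TypeVar


@dataclass
class SingleEvalResult:
    # per-sample scores, e.g. {"score": 1.0}
    score_data: Dict[str, float]


@dataclass
class EvalResult:
    # aggregated metrics, e.g. {"score": "0.83"}
    metrics: Dict[str, str]


TSample = TypeVar("TSample")
TProcessed = TypeVar("TProcessed")


class BaseTask(ABC, Generic[TSample, TProcessed]):
    # Per-sample hooks overridden by concrete tasks such as MMLUTask.
    @abstractmethod
    def preprocess_sample(self, sample: TSample) -> TProcessed: ...

    @abstractmethod
    def postprocess_sample(self, sample: TProcessed) -> TProcessed: ...

    @abstractmethod
    def score_sample(self, sample: TProcessed) -> SingleEvalResult: ...

    @abstractmethod
    def aggregate_results(self, results: List[SingleEvalResult]) -> EvalResult: ...

    # Batch methods the evals provider calls; they map the hooks over lists.
    def preprocess(self, dataset: Iterable[TSample]) -> List[TProcessed]:
        return [self.preprocess_sample(s) for s in dataset]

    def postprocess(self, samples: List[TProcessed]) -> List[TProcessed]:
        return [self.postprocess_sample(s) for s in samples]

    def score(self, samples: List[TProcessed]) -> List[SingleEvalResult]:
        return [self.score_sample(s) for s in samples]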