evals api mvp

This commit is contained in:
Xi Yan 2024-10-04 00:50:03 -07:00
parent 3cbe3a72e8
commit 2441e66d14
3 changed files with 15 additions and 28 deletions

View file

@@ -35,8 +35,7 @@ class MetaReferenceEvalsImpl(Evals):
         task: str,
     ) -> EvaluateResponse:
         cprint(f"model={model}, dataset={dataset}, task={task}", "red")
-        dataset = get_dataset("mmlu-simple-eval-en")
+        dataset = get_dataset(dataset)
         task_impl = get_task(task, dataset)
         x1 = task_impl.preprocess()
@@ -52,11 +51,6 @@ class MetaReferenceEvalsImpl(Evals):
             generation_outputs.append(x.completion_message.content)

        x2 = task_impl.postprocess(generation_outputs)
-        scores = task_impl.score(x2)
-        print(scores)
-
-        return EvaluateResponse(
-            metrics={
-                "accuracy": 0.5,
-            }
-        )
+        eval_results = task_impl.score(x2)
+        eval_response = task_impl.aggregate_results(eval_results)
+        return eval_response

View file

@@ -6,6 +6,7 @@
 import re

 from .task import BaseTask
+from llama_stack.apis.evals import *  # noqa: F403

 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
@@ -119,9 +120,6 @@ class MMLUTask(BaseTask):
         super().__init__(dataset, *args, **kwargs)

     def preprocess_sample(self, sample):
-        """
-        F1: preprocess sample
-        """
         content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
         return {
             "role": "user",
@@ -129,16 +127,10 @@ class MMLUTask(BaseTask):
         }

     def postprocess_sample(self, sample):
-        """
-        F2: postprocess sample
-        """
         normalized = normalize_response(sample)
         return normalized

     def score_sample(self, sample, expected):
-        """
-        F3: score sample
-        """
         extracted_answer = None
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
@@ -149,4 +141,10 @@ class MMLUTask(BaseTask):
         score = (
             1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
         )
+        # TODO: generalize this into SingleEvalResult
         return score
+
+    def aggregate_results(self, eval_results):
+        return EvaluateResponse(
+            metrics={"score": sum(eval_results) / len(eval_results)}
+        )

View file

@@ -21,23 +21,18 @@ class BaseTask(ABC):

     @abstractmethod
     def preprocess_sample(self, sample):
-        """
-        F1: preprocess sample
-        """
         raise NotImplementedError()

     @abstractmethod
     def postprocess_sample(self, sample):
-        """
-        F2: postprocess sample
-        """
         raise NotImplementedError()

     @abstractmethod
     def score_sample(self, sample, ground_truth):
-        """
-        F3: score sample
-        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def aggregate_results(self, eval_results):
         raise NotImplementedError()

     def preprocess(self):