Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

evals api mvp

Commit: 2441e66d14
Parent: 3cbe3a72e8
3 changed files with 15 additions and 28 deletions
@@ -35,8 +35,7 @@ class MetaReferenceEvalsImpl(Evals):
         task: str,
     ) -> EvaluateResponse:
         cprint(f"model={model}, dataset={dataset}, task={task}", "red")
-        dataset = get_dataset("mmlu-simple-eval-en")
+        dataset = get_dataset(dataset)

         task_impl = get_task(task, dataset)
         x1 = task_impl.preprocess()
@@ -52,11 +51,6 @@ class MetaReferenceEvalsImpl(Evals):
             generation_outputs.append(x.completion_message.content)

         x2 = task_impl.postprocess(generation_outputs)
-        scores = task_impl.score(x2)
-        print(scores)
-
-        return EvaluateResponse(
-            metrics={
-                "accuracy": 0.5,
-            }
-        )
+        eval_results = task_impl.score(x2)
+        eval_response = task_impl.aggregate_results(eval_results)
+        return eval_response
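For orientation, here is a minimal, self-contained sketch of the flow these two hunks establish: the hard-coded "mmlu-simple-eval-en" dataset and the hard-coded accuracy of 0.5 are gone, and the per-sample scores produced by the task are collapsed into the response via aggregate_results. Only the method names (preprocess, postprocess, score, aggregate_results) come from the diff; FakeTask, the generate callable standing in for the inference call not shown in the hunk, and the plain dict return value are placeholders invented for illustration.

# Placeholder task: echoes a fixed question set, scores exact matches, averages.
class FakeTask:
    def preprocess(self):
        return [{"role": "user", "content": q} for q in ("Q1", "Q2")]

    def postprocess(self, outputs):
        return [o.strip() for o in outputs]

    def score(self, outputs):
        expected = ["A", "B"]
        return [1.0 if o == e else 0.0 for o, e in zip(outputs, expected)]

    def aggregate_results(self, eval_results):
        # Stand-in for building an EvaluateResponse from per-sample scores.
        return {"metrics": {"score": sum(eval_results) / len(eval_results)}}


def run_eval(task, generate):
    # Mirrors the eval entry point after the change: no hard-coded dataset and
    # no hard-coded accuracy; the scores come from the task implementation.
    x1 = task.preprocess()
    generation_outputs = [generate(msg) for msg in x1]
    x2 = task.postprocess(generation_outputs)
    eval_results = task.score(x2)
    return task.aggregate_results(eval_results)


print(run_eval(FakeTask(), lambda msg: "A"))  # {'metrics': {'score': 0.5}}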
@@ -6,6 +6,7 @@
 import re

 from .task import BaseTask
+from llama_stack.apis.evals import *  # noqa: F403

 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
@@ -119,9 +120,6 @@ class MMLUTask(BaseTask):
         super().__init__(dataset, *args, **kwargs)

     def preprocess_sample(self, sample):
-        """
-        F1: preprocess sample
-        """
         content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
         return {
             "role": "user",
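The hunk above drops the placeholder docstring from preprocess_sample, which formats each dataset row into a user message via QUERY_TEMPLATE_MULTICHOICE. The template's full body and the row's column names are not visible in the diff, so the snippet below is only a hypothetical illustration of that formatting step, with both the template and the sample fields assumed.

# Hypothetical template and sample; only the "role": "user" message shape is
# taken from the diff.
TEMPLATE = (
    "Answer the following multiple choice question ...\n\n"
    "{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
)

sample = {"Question": "2 + 2 = ?", "A": "3", "B": "4", "C": "5", "D": "6"}
message = {"role": "user", "content": TEMPLATE.format(**sample)}
print(message["content"])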
@@ -129,16 +127,10 @@ class MMLUTask(BaseTask):
         }

     def postprocess_sample(self, sample):
-        """
-        F2: postprocess sample
-        """
         normalized = normalize_response(sample)
         return normalized

     def score_sample(self, sample, expected):
-        """
-        F3: score sample
-        """
         extracted_answer = None
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
@@ -149,4 +141,10 @@ class MMLUTask(BaseTask):
         score = (
             1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
         )
+        # TODO: generalize this into SingleEvalResult
         return score
+
+    def aggregate_results(self, eval_results):
+        return EvaluateResponse(
+            metrics={"score": sum(eval_results) / len(eval_results)}
+        )
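Taken together, the last two hunks make MMLUTask responsible for both per-sample scoring (extracting the final 'Answer: $LETTER' line with a regex and comparing it to the expected letter) and aggregation (the mean of the per-sample scores). The multilingual regex constants are not shown in the diff, so this simplified, English-only approximation of that path is an assumption.

# Simplified stand-in for the scoring path; the single ANSWER_PATTERN below is
# an assumption replacing the MULTILINGUAL_ANSWER_* constants not shown here.
import re

ANSWER_PATTERN = r"(?i)Answer\s*:\s*([ABCD])"

def score_sample(response: str, expected: dict) -> float:
    match = re.search(ANSWER_PATTERN, response)
    extracted_answer = match.group(1) if match else None
    return 1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0

def aggregate_results(eval_results: list) -> dict:
    # Same mean-of-scores aggregation as MMLUTask.aggregate_results above.
    return {"score": sum(eval_results) / len(eval_results)}

scores = [
    score_sample("Reasoning...\nAnswer: B", {"Answer": "B"}),
    score_sample("I am not sure.", {"Answer": "C"}),
]
print(aggregate_results(scores))  # {'score': 0.5}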
@@ -21,23 +21,18 @@ class BaseTask(ABC):

     @abstractmethod
     def preprocess_sample(self, sample):
-        """
-        F1: preprocess sample
-        """
         raise NotImplementedError()

     @abstractmethod
     def postprocess_sample(self, sample):
-        """
-        F2: postprocess sample
-        """
         raise NotImplementedError()

     @abstractmethod
     def score_sample(self, sample, ground_truth):
-        """
-        F3: score sample
-        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def aggregate_results(self, eval_results):
         raise NotImplementedError()

     def preprocess(self):
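After this hunk, aggregate_results is part of BaseTask's abstract contract, so every task must supply its own aggregation alongside the three per-sample hooks. A toy stand-in is sketched below; the real BaseTask's constructor and its preprocess/postprocess/score drivers are not shown in the diff, so only the four abstract hook names are taken from it.

from abc import ABC, abstractmethod


class Task(ABC):  # stand-in for BaseTask; only the hook names come from the diff
    @abstractmethod
    def preprocess_sample(self, sample): ...

    @abstractmethod
    def postprocess_sample(self, sample): ...

    @abstractmethod
    def score_sample(self, sample, ground_truth): ...

    @abstractmethod
    def aggregate_results(self, eval_results): ...


class EchoTask(Task):
    def preprocess_sample(self, sample):
        return {"role": "user", "content": sample}

    def postprocess_sample(self, sample):
        return sample.strip()

    def score_sample(self, sample, ground_truth):
        return 1.0 if sample == ground_truth else 0.0

    def aggregate_results(self, eval_results):
        return {"score": sum(eval_results) / len(eval_results)}


# A subclass that omits aggregate_results can no longer be instantiated:
# TypeError: Can't instantiate abstract class ... with abstract method aggregate_results
task = EchoTask()
print(task.aggregate_results([task.score_sample("B", "B"), task.score_sample("A", "C")]))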
|