Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-10 19:43:16 +00:00)
add data structure to tasks

commit ad18dc94ac (parent 9816c9aae6)
7 changed files with 100 additions and 168 deletions
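The rewritten task below is typed against sample and result structures pulled in through the wildcard import from llama_stack.apis.evals (DictSample, ProcessedDictSample, SingleEvalResult, EvalResult). Their definitions are not part of this hunk; the sketch below only reconstructs the fields this file actually touches (data, preprocessed, prediction.completion_message, postprocessed, score_data, metrics) and is an assumption for readability, not the real API:

# Rough sketch of the data structures this diff relies on, inferred solely from
# how MMLUTask uses them. The real classes live in llama_stack.apis.evals and
# may differ (llama-stack typically uses pydantic models rather than dataclasses).
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class DictSample:
    # raw dataset row, e.g. an MMLU record with question/choice/answer columns
    data: Dict[str, Any]


@dataclass
class PredictionSample:
    # model output for one sample; only completion_message is used in this file
    # (the name PredictionSample is itself an assumption)
    completion_message: str = ""


@dataclass
class ProcessedDictSample(DictSample):
    # filled in by preprocess_sample: chat messages to send to the model
    preprocessed: Optional[Dict[str, Any]] = None
    # filled in after inference
    prediction: Optional[PredictionSample] = None
    # filled in by postprocess_sample: normalized model output
    postprocessed: Optional[Dict[str, Any]] = None


@dataclass
class SingleEvalResult:
    # per-sample scores keyed by metric name
    score_data: Dict[str, float] = field(default_factory=dict)


@dataclass
class EvalResult:
    # aggregated metrics for the whole run
    metrics: Dict[str, str] = field(default_factory=dict)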
@@ -6,7 +6,8 @@
 import re

 from llama_stack.apis.evals import *  # noqa: F403
-from llama_stack.distribution.registry.tasks.task import BaseTask
+
+# from llama_stack.distribution.registry.tasks.task import BaseTask

 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
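Only the opening instruction line of QUERY_TEMPLATE_MULTICHOICE appears in this hunk; the rest of the template and the dataset's column names are not shown. As a purely illustrative sketch, assuming MMLU-style columns (Question, A, B, C, D, Answer), the new preprocess_sample call QUERY_TEMPLATE_MULTICHOICE.format(**sample.data) would expand like this:

# Hypothetical raw row; the actual column names depend on the dataset and the
# template's placeholders, neither of which is visible in this diff.
sample_data = {
    "Question": "Which planet is known as the Red Planet?",
    "A": "Venus",
    "B": "Mars",
    "C": "Jupiter",
    "D": "Mercury",
    "Answer": "B",
}

# str.format substitutes each {Key} placeholder in the template from the row;
# extra keys (such as "Answer") are simply ignored by format(**...).
prompt = QUERY_TEMPLATE_MULTICHOICE.format(**sample_data)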
@@ -111,48 +112,60 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
     )


-class MMLUTask(
-    BaseTask[
-        DictSample,
-        InferencePreprocessedSample,
-        InferencePredictionSample,
-        InferencePostprocessedSample,
-    ]
-):
+class MMLUTask(BaseTask[DictSample, ProcessedDictSample]):
     """
     MMLU Task.
     """

-    def __init__(self, dataset, *args, **kwargs):
-        super().__init__(dataset, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

-    def preprocess_sample(self, sample):
-        print(sample)
-        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
-        return {
-            "role": "user",
-            "content": content,
+    def preprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample.data)
+        preprocessed = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ],
         }
+        processed_sample = ProcessedDictSample(
+            data=sample.data,
+            preprocessed=preprocessed,
+        )
+        return processed_sample

-    def postprocess_sample(self, sample):
-        normalized = normalize_response(sample)
-        return normalized
+    def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        if not sample.postprocessed:
+            sample.postprocessed = {}
+        sample.postprocessed["postprocessed"] = normalize_response(
+            sample.prediction.completion_message
+        )
+        return sample

-    def score_sample(self, sample, expected):
+    def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
+        postprocessed_output = sample.postprocessed["postprocessed"]
+        expected_answer = sample.data["Answer"]
+
         extracted_answer = None
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-            match = re.search(regex, sample)
+            match = re.search(regex, postprocessed_output)
             if match:
                 extracted_answer = normalize_extracted_answer(match.group(1))
                 break
-        score = (
-            1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
-        )
-        # TODO: generalize this into SingleEvalResult
-        return score
-
-    def aggregate_results(self, eval_results):
-        return EvaluateResponse(
-            metrics={"score": str(sum(eval_results) / len(eval_results))}
+        score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
+
+        return SingleEvalResult(
+            score_data={
+                "score": score,
+            },
         )
+
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        print("aggregate_results", eval_results)
+        sum_score = sum([result.score_data["score"] for result in eval_results])
+
+        return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
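Read together, the rewritten methods imply a per-sample pipeline: preprocess_sample builds the chat messages, inference runs elsewhere and fills in sample.prediction, postprocess_sample normalizes the completion, score_sample extracts and grades the answer letter, and aggregate_results averages the per-sample scores. A minimal driver sketch under those assumptions and the hypothetical data structures above (the inference_fn name and the loop itself are illustrations, not part of this commit):

def run_mmlu_eval(task, dataset, inference_fn):
    # dataset: iterable of DictSample; inference_fn: messages -> PredictionSample.
    # Both are assumptions for illustration; the real runner lives elsewhere.
    eval_results = []
    for raw in dataset:
        sample = ProcessedDictSample(data=raw.data)

        # Build the chat messages for the model.
        sample = task.preprocess_sample(sample)

        # Run the model on the preprocessed messages (outside MMLUTask).
        sample.prediction = inference_fn(sample.preprocessed["messages"])

        # Normalize the completion, then extract and score the answer letter.
        sample = task.postprocess_sample(sample)
        eval_results.append(task.score_sample(sample))

    # Average the per-sample scores into a single EvalResult.
    return task.aggregate_results(eval_results)

# e.g. result = run_mmlu_eval(MMLUTask(), dataset, inference_fn)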