mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-09 19:29:18 +00:00)

add data structure to tasks

commit ad18dc94ac (parent 9816c9aae6)
7 changed files with 100 additions and 168 deletions
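The hunks below pass around a handful of sample and result containers (DictSample, ProcessedDictSample, PredictionSample, SingleEvalResult, EvalResult, EvaluateResponse) whose definitions live in files of this commit that are not shown here. The following is a minimal sketch of how those containers appear to be used, with every field inferred from this diff rather than taken from the real llama_stack.apis definitions:

# Hypothetical sketch only: field names are inferred from how the diff below
# uses these objects, not copied from llama_stack.apis.evals / apis.dataset.
from typing import Any, Dict, Optional

from pydantic import BaseModel  # assumption: llama-stack models are pydantic-based


class DictSample(BaseModel):
    # one raw dataset row, e.g. {"Question": ..., "A": ..., ..., "Answer": "C"}
    data: Dict[str, Any]


class PredictionSample(BaseModel):
    # generation output attached to a sample after inference
    completion_message: str


class ProcessedDictSample(BaseModel):
    # the same row as it moves through preprocess -> generate -> postprocess
    data: Dict[str, Any]
    preprocessed: Optional[Dict[str, Any]] = None   # e.g. {"messages": [...]}
    prediction: Optional[PredictionSample] = None
    postprocessed: Optional[Dict[str, Any]] = None  # e.g. {"postprocessed": "..."}


class SingleEvalResult(BaseModel):
    # per-sample scores keyed by metric name
    score_data: Dict[str, float]


class EvalResult(BaseModel):
    # aggregated metrics for the whole run
    metrics: Dict[str, str]


class EvaluateResponse(BaseModel):
    eval_result: EvalResult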
@@ -6,6 +6,8 @@
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.dataset import *  # noqa: F403

 from termcolor import cprint

+from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry

@@ -42,30 +44,37 @@ class MetaReferenceEvalsImpl(Evals):
         dataset = DatasetRegistry.get_dataset(dataset)
         dataset.load()

-        task_impl = TaskRegistry.get_task(task)(dataset)
-        x1 = task_impl.preprocess()
+        task_impl = TaskRegistry.get_task(task)()
+        preprocessed = task_impl.preprocess(dataset)

         # TODO: replace w/ batch inference & async return eval job
         generation_outputs = []
         if eval_task_config is None:
-            eval_task_config = EvaluateTaskConfig(n_samples=len(x1))
-        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(x1):
-            eval_task_config.n_samples = len(x1)
+            eval_task_config = EvaluateTaskConfig(n_samples=len(preprocessed))
+        if eval_task_config.n_samples is None or eval_task_config.n_samples > len(
+            preprocessed
+        ):
+            eval_task_config.n_samples = len(preprocessed)

         print(
             f"Eval generation start, generate on {eval_task_config.n_samples} samples"
         )

-        for msg in x1[: eval_task_config.n_samples]:
-            print("generation for msg: ", msg)
+        for sample in preprocessed[: eval_task_config.n_samples]:
+            print("generation: ", sample)
             response = await self.inference_api.chat_completion(
                 model=model,
-                messages=[msg],
+                messages=sample.preprocessed["messages"],
                 stream=False,
             )
-            generation_outputs.append(response.completion_message.content)
+            sample.prediction = PredictionSample(
+                completion_message=response.completion_message.content
+            )
+            generation_outputs.append(sample)

-        x2 = task_impl.postprocess(generation_outputs)
-        eval_results = task_impl.score(x2)
-        eval_response = task_impl.aggregate_results(eval_results)
-        return eval_response
+        postprocessed = task_impl.postprocess(generation_outputs)
+        eval_results = task_impl.score(postprocessed)
+        aggr_result = task_impl.aggregate_results(eval_results)
+        return EvaluateResponse(
+            eval_result=aggr_result,
+        )
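The rewritten loop above threads one ProcessedDictSample object through every stage instead of passing bare message lists and strings. Below is a self-contained sketch of that flow with a stubbed task and a canned completion; the dataclasses and the fake_chat_completion helper are illustrative stand-ins, not the llama-stack inference API:

# Illustrative stand-in for the new per-sample flow (preprocess -> generate ->
# postprocess -> score -> aggregate). The classes and the fake completion call
# below are local stubs, not llama-stack APIs.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class PredictionSample:
    completion_message: str


@dataclass
class ProcessedDictSample:
    data: Dict[str, Any]
    preprocessed: Dict[str, Any] = field(default_factory=dict)
    prediction: Optional[PredictionSample] = None
    postprocessed: Dict[str, Any] = field(default_factory=dict)


def fake_chat_completion(messages: List[Dict[str, str]]) -> str:
    # stands in for `await self.inference_api.chat_completion(...)`
    return "Paris is the capital of France.\nAnswer: B"


def run_eval(rows: List[Dict[str, Any]]) -> Dict[str, str]:
    # preprocess: wrap each raw row into a sample that carries chat messages
    samples = [
        ProcessedDictSample(
            data=row,
            preprocessed={"messages": [{"role": "user", "content": row["Question"]}]},
        )
        for row in rows
    ]

    # generate: attach the model output to the same sample object
    for sample in samples:
        completion = fake_chat_completion(sample.preprocessed["messages"])
        sample.prediction = PredictionSample(completion_message=completion)

    # postprocess + score: everything a scorer needs now rides on the sample
    scores = []
    for sample in samples:
        sample.postprocessed["postprocessed"] = sample.prediction.completion_message.strip()
        predicted = sample.postprocessed["postprocessed"].rsplit("Answer:", 1)[-1].strip()
        scores.append(1.0 if predicted == sample.data["Answer"] else 0.0)

    # aggregate
    return {"score": str(sum(scores) / len(scores))}


if __name__ == "__main__":
    row = {"Question": "What is the capital of France? A. Lyon B. Paris C. Nice D. Lille", "Answer": "B"}
    print(run_eval([row]))  # {'score': '1.0'}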
@@ -6,7 +6,8 @@
 import re

+from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.distribution.registry.tasks.task import BaseTask

-# from llama_stack.distribution.registry.tasks.task import BaseTask

 QUERY_TEMPLATE_MULTICHOICE = """
 Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.
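preprocess_sample in the hunk below fills QUERY_TEMPLATE_MULTICHOICE with sample.data via str.format. Only the template's instruction line is visible in this diff, so the placeholder names in the sketch ({Question}, {A}..{D}) are hypothetical; the snippet just illustrates the formatting step the task performs:

# Hypothetical template body: only the instruction line is shown in the hunk,
# so the {Question}/{A}..{D} placeholders are assumed, not copied from the file.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question and make the answer very simple. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

row = {
    "Question": "Which planet is known as the Red Planet?",
    "A": "Venus",
    "B": "Mars",
    "C": "Jupiter",
    "D": "Saturn",
    "Answer": "B",  # extra keys are ignored by str.format(**row)
}

# mirrors QUERY_TEMPLATE_MULTICHOICE.format(**sample.data) in preprocess_sample
print(QUERY_TEMPLATE_MULTICHOICE.format(**row))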
@@ -111,48 +112,60 @@ def normalize_extracted_answer(extracted_answer: str) -> str:
     )


-class MMLUTask(
-    BaseTask[
-        DictSample,
-        InferencePreprocessedSample,
-        InferencePredictionSample,
-        InferencePostprocessedSample,
-    ]
-):
+class MMLUTask(BaseTask[DictSample, ProcessedDictSample]):
     """
     MMLU Task.
     """

-    def __init__(self, dataset, *args, **kwargs):
-        super().__init__(dataset, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

-    def preprocess_sample(self, sample):
-        print(sample)
-        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample)
-        return {
-            "role": "user",
-            "content": content,
+    def preprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        content = QUERY_TEMPLATE_MULTICHOICE.format(**sample.data)
+        preprocessed = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ],
         }
+        processed_sample = ProcessedDictSample(
+            data=sample.data,
+            preprocessed=preprocessed,
+        )
+        return processed_sample

-    def postprocess_sample(self, sample):
-        normalized = normalize_response(sample)
-        return normalized
+    def postprocess_sample(self, sample: ProcessedDictSample) -> ProcessedDictSample:
+        if not sample.postprocessed:
+            sample.postprocessed = {}
+        sample.postprocessed["postprocessed"] = normalize_response(
+            sample.prediction.completion_message
+        )
+        return sample
+
+    def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
+        postprocessed_output = sample.postprocessed["postprocessed"]
+        expected_answer = sample.data["Answer"]

-    def score_sample(self, sample, expected):
         extracted_answer = None
         for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
             regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-            match = re.search(regex, sample)
+            match = re.search(regex, postprocessed_output)
             if match:
                 extracted_answer = normalize_extracted_answer(match.group(1))
                 break
-        score = (
-            1.0 if extracted_answer and extracted_answer == expected["Answer"] else 0.0
-        )
-        # TODO: generalize this into SingleEvalResult
-        return score
-
-    def aggregate_results(self, eval_results):
-        return EvaluateResponse(
-            metrics={"score": str(sum(eval_results) / len(eval_results))}
+        score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
+
+        return SingleEvalResult(
+            score_data={
+                "score": score,
+            },
         )
+
+    def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
+        print("aggregate_results", eval_results)
+        sum_score = sum([result.score_data["score"] for result in eval_results])
+
+        return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
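score_sample relies on MULTILINGUAL_ANSWER_PATTERN_TEMPLATE and MULTILINGUAL_ANSWER_REGEXES, defined earlier in this file and not visible in the hunk, to pull the final 'Answer: $LETTER' line out of a completion. Below is a simplified, self-contained approximation of that extraction and of how per-sample results roll up into the aggregate metric; the single English-only regex is a stand-in for the real multilingual pattern list:

# Simplified stand-in for MMLUTask.score_sample / aggregate_results.
# A single English-only regex replaces the multilingual pattern list used in
# the real file; the 1.0/0.0 scoring and the mean aggregation match the diff.
import re
from typing import List, Optional

ANSWER_PATTERN = re.compile(r"Answer\s*:\s*\$?([A-D])", re.IGNORECASE)


def extract_answer(completion: str) -> Optional[str]:
    match = ANSWER_PATTERN.search(completion)
    return match.group(1).upper() if match else None


def score_sample(completion: str, expected_answer: str) -> float:
    extracted = extract_answer(completion)
    return 1.0 if extracted and extracted == expected_answer else 0.0


def aggregate(scores: List[float]) -> dict:
    # mirrors EvalResult(metrics={"score": str(sum_score / len(eval_results))})
    return {"score": str(sum(scores) / len(scores))}


if __name__ == "__main__":
    outputs = [
        ("The correct option is C.\nAnswer: C", "C"),
        ("I think the answer is B.\nAnswer: B", "A"),
    ]
    print(aggregate([score_sample(c, e) for c, e in outputs]))  # {'score': '0.5'}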