RunEvalTask / InferenceGenerator

Xi Yan 2024-10-13 23:48:15 -07:00
parent 8890de7322
commit 78cb88c3c4
5 changed files with 111 additions and 84 deletions


@@ -7,25 +7,14 @@ import json
 from termcolor import cprint
-from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
-    AggregateScorer,
-)
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
 from llama_stack.apis.dataset import *  # noqa: F403
-from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
-from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
-    MMLUProcessor,
-)
-from .config import MetaReferenceEvalsImplConfig
-
-# from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
-# from .tasks.run_eval_task import RunEvalTask
-from .scorer.basic_scorers import *  # noqa: F403
+from .config import MetaReferenceEvalsImplConfig
+from .tasks.run_eval_task import RunEvalTask


 class MetaReferenceEvalsImpl(Evals):
@@ -70,58 +59,8 @@ class MetaReferenceEvalsImpl(Evals):
             ),
         )

-        # TODO: wrap inside task
-        # run_task = RunEvalTask(
-        #     eval_task_config=eval_task_config,
-        # )
-        # eval_result = run_task.run()
-
-        dataset = DatasetRegistry.get_dataset(
-            eval_task_config.dataset_config.dataset_name
-        )
-        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)
-        print(f"Running on {len(dataset)} samples")
-
-        # F1
-        processor = MMLUProcessor()
-        preprocessed = processor.preprocess(dataset)
-
-        # Generation
-        # TODO: wrap inside BaseGenerator
-        generation_outputs = []
-        for sample in preprocessed:
-            print("generation: ", sample)
-            response = await self.inference_api.chat_completion(
-                model=model,
-                messages=sample.generation_input.messages,
-                stream=False,
-            )
-            cprint(f"response: {response}", "cyan")
-            generation_outputs.append(
-                GenerationResponseSample(
-                    generation_output=GenerationOutput(
-                        completion_message=response.completion_message.content
-                    )
-                )
-            )
-        cprint(generation_outputs, "green")
-
-        # F2
-        postprocessed = processor.postprocess(generation_outputs, dataset)
-        cprint(postprocessed, "blue")
-
-        # F3 - scorer
-        scorer = AggregateScorer(
-            scorers=[
-                AccuracyScorer(),
-                RandomScorer(),
-            ]
-        )
-
-        scorer_results = scorer.score(postprocessed)
-        cprint(scorer_results, "magenta")
-        eval_result = scorer.aggregate_results(scorer_results)
+        run_task = RunEvalTask()
+        eval_result = await run_task.run(eval_task_config, self.inference_api)

         return EvaluateResponse(
             eval_result=eval_result,
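
The RunEvalTask class itself is not shown in this hunk; it lives in .tasks.run_eval_task, one of the other changed files. Below is a plausible sketch of what its run() method encapsulates, reconstructed from the inline pipeline deleted above. The class layout, the explicit imports of the scorer and generation types, and reading the model name from the task config are assumptions based on that removed code, not the actual file contents.

# Sketch of RunEvalTask, reconstructed from the deleted inline pipeline above.
# Assumptions: run() mirrors the removed dataset -> preprocess -> generate ->
# postprocess -> score flow, and the model name is read from eval_task_config.
from llama_stack.apis.evals import *  # noqa: F403  (GenerationOutput, GenerationResponseSample, ...)
from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
    MMLUProcessor,
)
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
    AccuracyScorer,
    AggregateScorer,
    RandomScorer,
)


class RunEvalTask:
    """Wraps dataset loading, generation, and scoring behind a single run() call."""

    async def run(self, eval_task_config, inference_api):
        # Load the dataset named in the task config, capped at row_limit rows.
        dataset = DatasetRegistry.get_dataset(
            eval_task_config.dataset_config.dataset_name
        )
        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)

        # F1: turn raw rows into chat-completion inputs.
        processor = MMLUProcessor()
        preprocessed = processor.preprocess(dataset)

        # Generation: one chat completion per preprocessed sample.
        generation_outputs = []
        for sample in preprocessed:
            response = await inference_api.chat_completion(
                model=eval_task_config.model,  # assumption: model comes from the config
                messages=sample.generation_input.messages,
                stream=False,
            )
            generation_outputs.append(
                GenerationResponseSample(
                    generation_output=GenerationOutput(
                        completion_message=response.completion_message.content
                    )
                )
            )

        # F2: align generations with the dataset's expected answers.
        postprocessed = processor.postprocess(generation_outputs, dataset)

        # F3: score each sample and aggregate into a single eval result.
        scorer = AggregateScorer(scorers=[AccuracyScorer(), RandomScorer()])
        scorer_results = scorer.score(postprocessed)
        return scorer.aggregate_results(scorer_results)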