RunEvalTask / InferenceGenerator

Xi Yan 2024-10-13 23:48:15 -07:00
parent 8890de7322
commit 78cb88c3c4
5 changed files with 111 additions and 84 deletions


@@ -4,8 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
 from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
+    InferenceGenerator,
+)
+from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
+    MMLUProcessor,
+)
 from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+from termcolor import cprint


 class RunEvalTask(BaseTask):
@@ -15,25 +24,51 @@ class RunEvalTask(BaseTask):
     def __init__(
         self,
-        eval_task_config,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
         *args,
         **kwargs,
     ) -> None:
-        super().__init__(
-            generator_processor=generator_processor,
-            generator=generator,
-            scorer=scorer,
-            *args,
-            **kwargs,
-        )
-        self.eval_task_config = eval_task_config
-        self.dataset = DatasetRegistry.get_dataset(
+        super().__init__(*args, **kwargs)
+
+    async def run(
+        self,
+        eval_task_config: EvaluateTaskConfig,
+        inference_api: Inference,
+        *args,
+        **kwargs,
+    ) -> EvalResult:
+        print(f"Running eval task w/ {eval_task_config}")
+
+        dataset = DatasetRegistry.get_dataset(
             eval_task_config.dataset_config.dataset_name
         )
+        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)
+        print(f"Running on {len(dataset)} samples")

-    def run(self, *args, **kwargs) -> EvalResult:
-        print(f"Running eval task on {self.dataset}")
-        return EvalResult()
+        # F1
+        processor = MMLUProcessor()
+        preprocessed = processor.preprocess(dataset)
+
+        # Generation
+        generator = InferenceGenerator(
+            model=eval_task_config.generation_config.model,
+            inference_api=inference_api,
+        )
+        generation_outputs = await generator.generate(preprocessed)
+
+        # F2
+        postprocessed = processor.postprocess(generation_outputs, dataset)
+        cprint(postprocessed, "blue")
+
+        # F3 - scorer
+        scorer = AggregateScorer(
+            scorers=[
+                AccuracyScorer(),
+                RandomScorer(),
+            ]
+        )
+        scorer_results = scorer.score(postprocessed)
+        cprint(scorer_results, "magenta")
+
+        eval_result = scorer.aggregate_results(scorer_results)
+        return eval_result
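
For orientation, the shape of the new flow: run() now receives the task config and an Inference implementation per call instead of binding a processor, generator, and scorer at construction time, and the # F1, # F2, and # F3 comments mark the preprocess, postprocess, and scoring stages around generation. Below is a minimal sketch of a caller, assuming the wildcard-imported EvaluateTaskConfig accepts keyword sub-configs matching the attribute paths read inside run(); the sub-config class names DatasetConfig and GenerationConfig, the dataset and model values, and the RunEvalTask import path are hypothetical, since this excerpt does not name the changed file or show those definitions.

import asyncio

from llama_stack.apis.evals import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.tasks.run_eval_task import (
    RunEvalTask,  # module path assumed; the diff view above does not name the file
)


async def main(inference_api: Inference) -> None:
    # Hypothetical constructor shape: only the attribute paths that run()
    # reads (dataset_config.dataset_name, dataset_config.row_limit,
    # generation_config.model) are confirmed by the diff above.
    config = EvaluateTaskConfig(
        dataset_config=DatasetConfig(        # assumed sub-config class name
            dataset_name="mmlu",             # must be registered in DatasetRegistry
            row_limit=16,                    # forwarded to dataset.load(n_samples=...)
        ),
        generation_config=GenerationConfig(  # assumed sub-config class name
            model="Llama3.1-8B-Instruct",    # placeholder model identifier
        ),
    )
    task = RunEvalTask()  # __init__ now just forwards *args/**kwargs to BaseTask
    result = await task.run(config, inference_api)
    print(result)  # EvalResult aggregated by AggregateScorer


if __name__ == "__main__":
    asyncio.run(main(inference_api=...))  # supply a concrete Inference implementation

Making run() async lets the task await generator.generate(preprocessed) against the inference API directly, so the whole eval pipeline can share one event loop with the inference provider rather than blocking on generation.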