mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-10 11:39:47 +00:00
RunEvalTask / InferenceGenerator
parent 8890de7322
commit 78cb88c3c4
5 changed files with 111 additions and 84 deletions
@@ -4,8 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
 from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
+    InferenceGenerator,
+)
+from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
+    MMLUProcessor,
+)
 
 from llama_stack.apis.evals import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
 from termcolor import cprint
 
+
 class RunEvalTask(BaseTask):
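The two newly imported evals modules are exercised by the reworked run() in the next hunk. As a reading aid, here is a rough sketch of the surface this diff assumes they expose, inferred only from the call sites below; the real llama_stack implementations may differ:

# Hypothetical interface sketch inferred from call sites in run();
# not the actual llama_stack source.
class InferenceGenerator:
    def __init__(self, model: str, inference_api) -> None:
        self.model = model                  # model id from generation_config
        self.inference_api = inference_api  # an Inference implementation

    async def generate(self, preprocessed):
        """Run inference per preprocessed sample and collect the outputs."""
        ...


class MMLUProcessor:
    def preprocess(self, dataset):
        """Convert raw MMLU rows into generation-ready inputs."""
        ...

    def postprocess(self, generation_outputs, dataset):
        """Join model outputs with gold answers for scoring."""
        ...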
@@ -15,25 +24,51 @@ class RunEvalTask(BaseTask):
 
     def __init__(
         self,
-        eval_task_config,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
         *args,
         **kwargs,
     ) -> None:
-        super().__init__(
-            generator_processor=generator_processor,
-            generator=generator,
-            scorer=scorer,
-            *args,
-            **kwargs,
-        )
-        self.eval_task_config = eval_task_config
-        self.dataset = DatasetRegistry.get_dataset(
+        super().__init__(*args, **kwargs)
+
+    async def run(
+        self,
+        eval_task_config: EvaluateTaskConfig,
+        inference_api: Inference,
+        *args,
+        **kwargs,
+    ) -> EvalResult:
+        print(f"Running eval task w/ {eval_task_config}")
+
+        dataset = DatasetRegistry.get_dataset(
             eval_task_config.dataset_config.dataset_name
         )
+        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)
+        print(f"Running on {len(dataset)} samples")
 
-    def run(self, *args, **kwargs) -> EvalResult:
-        print(f"Running eval task on {self.dataset}")
-        return EvalResult()
+        # F1
+        processor = MMLUProcessor()
+        preprocessed = processor.preprocess(dataset)
+
+        # Generation
+        generator = InferenceGenerator(
+            model=eval_task_config.generation_config.model,
+            inference_api=inference_api,
+        )
+        generation_outputs = await generator.generate(preprocessed)
+
+        # F2
+        postprocessed = processor.postprocess(generation_outputs, dataset)
+        cprint(postprocessed, "blue")
+
+        # F3 - scorer
+        scorer = AggregateScorer(
+            scorers=[
+                AccuracyScorer(),
+                RandomScorer(),
+            ]
+        )
+
+        scorer_results = scorer.score(postprocessed)
+        cprint(scorer_results, "magenta")
+        eval_result = scorer.aggregate_results(scorer_results)
+
+        return eval_result
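After this change, RunEvalTask no longer binds a dataset, generator, or scorer at construction time; everything is resolved inside run() from the EvaluateTaskConfig, so a caller only needs a config and an Inference implementation. A minimal, hypothetical driver sketch — the helper name and the my_* placeholders are assumptions, not part of this commit:

import asyncio

# Hypothetical driver; assumes RunEvalTask is importable from the module
# this commit edits, and a config/inference pair is already in hand.
async def run_mmlu_eval(config, inference):
    task = RunEvalTask()  # task is now stateless; nothing bound at init
    return await task.run(eval_task_config=config, inference_api=inference)

# result = asyncio.run(run_mmlu_eval(my_eval_config, my_inference_api))
# print(result)  # an EvalResult aggregated by AggregateScorer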