Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-09 19:29:18 +00:00)

RunEvalTask / InferenceGenerator

commit 78cb88c3c4, parent 8890de7322

5 changed files with 111 additions and 84 deletions
@@ -7,25 +7,14 @@ import json

-from termcolor import cprint
-
-from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
-    AggregateScorer,
-)

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.evals import *  # noqa: F403
 from llama_stack.apis.dataset import *  # noqa: F403

-from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
-from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
-    MMLUProcessor,
-)
-from .config import MetaReferenceEvalsImplConfig
-
-# from llama_stack.distribution.registry.tasks.task_registry import TaskRegistry
-# from .tasks.run_eval_task import RunEvalTask
 from .scorer.basic_scorers import *  # noqa: F403


+from .config import MetaReferenceEvalsImplConfig
+from .tasks.run_eval_task import RunEvalTask


 class MetaReferenceEvalsImpl(Evals):
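Aside: the AggregateScorer import dropped above backs the scoring step that the next hunk deletes. For reference, a small dependency-free sketch of that composition pattern, with interfaces assumed from the call sites in the hunk below (score() over postprocessed samples, aggregate_results() over the merged metrics); only the class names AggregateScorer, AccuracyScorer, and RandomScorer come from the diff, everything else is illustrative.

# --- Illustrative sketch, not part of the commit ---
# Dependency-free model of the scorer composition used by the code removed in
# the next hunk. Interfaces are assumed from the call sites there; only the
# class names AggregateScorer / AccuracyScorer / RandomScorer appear in the diff.
import random
from typing import Any, Dict, List


class AccuracyScorer:
    def score(self, postprocessed: List[Dict[str, Any]]) -> Dict[str, float]:
        # Assumes each postprocessed sample carries a prediction and a reference.
        correct = sum(1 for s in postprocessed if s["prediction"] == s["expected"])
        return {"accuracy": correct / max(len(postprocessed), 1)}


class RandomScorer:
    def score(self, postprocessed: List[Dict[str, Any]]) -> Dict[str, float]:
        # Baseline metric, useful as a sanity check.
        return {"random": random.random()}


class AggregateScorer:
    def __init__(self, scorers: List[Any]) -> None:
        self.scorers = scorers

    def score(self, postprocessed: List[Dict[str, Any]]) -> Dict[str, float]:
        # Run every child scorer and merge their metric dictionaries.
        results: Dict[str, float] = {}
        for scorer in self.scorers:
            results.update(scorer.score(postprocessed))
        return results

    def aggregate_results(self, scorer_results: Dict[str, float]) -> Dict[str, float]:
        # Pass-through in this sketch; a real implementation may fold metrics further.
        return dict(scorer_results)


if __name__ == "__main__":
    samples = [{"prediction": "A", "expected": "A"},
               {"prediction": "B", "expected": "C"}]
    scorer = AggregateScorer(scorers=[AccuracyScorer(), RandomScorer()])
    print(scorer.aggregate_results(scorer.score(samples)))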
@@ -70,58 +59,8 @@ class MetaReferenceEvalsImpl(Evals):
             ),
         )

-        # TODO: wrap inside task
-        # run_task = RunEvalTask(
-        #     eval_task_config=eval_task_config,
-        # )
-        # eval_result = run_task.run()
-
-        dataset = DatasetRegistry.get_dataset(
-            eval_task_config.dataset_config.dataset_name
-        )
-        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)
-        print(f"Running on {len(dataset)} samples")
-
-        # F1
-        processor = MMLUProcessor()
-        preprocessed = processor.preprocess(dataset)
-
-        # Generation
-        # TODO: wrap inside BaseGenerator
-        generation_outputs = []
-        for sample in preprocessed:
-            print("generation: ", sample)
-            response = await self.inference_api.chat_completion(
-                model=model,
-                messages=sample.generation_input.messages,
-                stream=False,
-            )
-            cprint(f"response: {response}", "cyan")
-
-            generation_outputs.append(
-                GenerationResponseSample(
-                    generation_output=GenerationOutput(
-                        completion_message=response.completion_message.content
-                    )
-                )
-            )
-        cprint(generation_outputs, "green")
-
-        # F2
-        postprocessed = processor.postprocess(generation_outputs, dataset)
-        cprint(postprocessed, "blue")
-
-        # F3 - scorer
-        scorer = AggregateScorer(
-            scorers=[
-                AccuracyScorer(),
-                RandomScorer(),
-            ]
-        )
-
-        scorer_results = scorer.score(postprocessed)
-        cprint(scorer_results, "magenta")
-        eval_result = scorer.aggregate_results(scorer_results)
+        run_task = RunEvalTask()
+        eval_result = await run_task.run(eval_task_config, self.inference_api)

         return EvaluateResponse(
             eval_result=eval_result,
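The RunEvalTask constructed in the new lines above, and the inference generator that the commit title refers to, live in other files of this commit that are not shown on this page. The following is a rough sketch of what the refactor implies, repackaging the deleted inline pipeline behind run(eval_task_config, inference_api); it presupposes the llama-stack tree at this commit, and anything not visible in the hunks above, in particular how the task obtains the model name, is an assumption rather than the actual implementation.

# --- Illustrative sketch, not part of the commit ---
# The real RunEvalTask (tasks/run_eval_task.py) and the generator implied by the
# commit title are in changed files not shown on this page. This sketch only
# repackages the inline pipeline deleted in the second hunk behind the
# run(eval_task_config, inference_api) call that the new lines make; every call
# it contains appears verbatim in the removed lines, except where a comment
# marks an assumption.
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.evals import *  # noqa: F403
from llama_stack.apis.dataset import *  # noqa: F403

from llama_stack.distribution.registry.datasets.dataset_registry import DatasetRegistry
from llama_stack.providers.impls.meta_reference.evals.processor.mmlu_processor import (
    MMLUProcessor,
)
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import (
    AccuracyScorer,  # assumption: exported by basic_scorers, as the star import suggests
    AggregateScorer,
    RandomScorer,
)


class InferenceGenerator:
    """Sketch of a generator wrapping the per-sample chat_completion loop."""

    def __init__(self, model, inference_api):
        self.model = model
        self.inference_api = inference_api

    async def generate(self, preprocessed):
        outputs = []
        for sample in preprocessed:
            response = await self.inference_api.chat_completion(
                model=self.model,
                messages=sample.generation_input.messages,
                stream=False,
            )
            outputs.append(
                GenerationResponseSample(
                    generation_output=GenerationOutput(
                        completion_message=response.completion_message.content
                    )
                )
            )
        return outputs


class RunEvalTask:
    """Sketch: load dataset, preprocess, generate, postprocess, score, aggregate."""

    async def run(self, eval_task_config, inference_api):
        dataset = DatasetRegistry.get_dataset(
            eval_task_config.dataset_config.dataset_name
        )
        dataset.load(n_samples=eval_task_config.dataset_config.row_limit)

        processor = MMLUProcessor()
        preprocessed = processor.preprocess(dataset)

        # Assumption: the new call site no longer passes `model` explicitly, so
        # this sketch reads it from the config; the real field is not shown here.
        generator = InferenceGenerator(eval_task_config.model, inference_api)
        generation_outputs = await generator.generate(preprocessed)

        postprocessed = processor.postprocess(generation_outputs, dataset)

        scorer = AggregateScorer(scorers=[AccuracyScorer(), RandomScorer()])
        scorer_results = scorer.score(postprocessed)
        return scorer.aggregate_results(scorer_results)

This layout matches the TODO comments being deleted ("wrap inside task", "wrap inside BaseGenerator"): the provider now only builds the task and returns EvaluateResponse, while dataset handling, generation, and scoring sit behind the task boundary.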