scorer registry

This commit is contained in:
Xi Yan 2024-10-14 15:41:31 -07:00
parent 9c501d042b
commit c50686b6fe
5 changed files with 55 additions and 32 deletions

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.registry.datasets import DatasetRegistry
from llama_stack.distribution.registry.scorers import ScorerRegistry
from llama_stack.providers.impls.meta_reference.evals.scorer.aggregate_scorer import * # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import * # noqa: F403
from llama_stack.providers.impls.meta_reference.evals.generator.inference_generator import (
InferenceGenerator,
@ -59,11 +61,14 @@ class RunEvalTask(BaseTask):
cprint(postprocessed, "blue")
# F3 - scorer
scorer_config_list = eval_task_config.scoring_config.scorer_config_list
scorer_list = []
for s_conf in scorer_config_list:
scorer = ScorerRegistry.get(s_conf.scorer_name)
scorer_list.append(scorer())
scorer = AggregateScorer(
scorers=[
AccuracyScorer(),
RandomScorer(),
]
scorers=scorer_list,
)
scorer_results = scorer.score(postprocessed)