From ba0186f2c8f742cc20599190dfba3106e92fc471 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 24 Oct 2024 14:00:41 -0700 Subject: [PATCH] refactor --- .../meta_reference/scoring/scorer/common.py | 19 +++++++++++++++++++ .../scoring/scorer/equality_scorer.py | 13 ++++--------- .../scoring/scorer/inclusion_scorer.py | 13 ++++--------- 3 files changed, 27 insertions(+), 18 deletions(-) create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scorer/common.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/common.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/common.py new file mode 100644 index 000000000..52eabea2e --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/common.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict, List + +from llama_stack.apis.scoring import ScoringResultRow + + +def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + num_correct = sum(result["score"] for result in scoring_results) + avg_score = num_correct / len(scoring_results) + + return { + "accuracy": avg_score, + "num_correct": num_correct, + "num_total": len(scoring_results), + } diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py index ce765bfb5..0c7751f35 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py @@ -10,6 +10,9 @@ from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer impor from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.providers.impls.meta_reference.scoring.scorer.common import ( + aggregate_accuracy, +) class EqualityScorer(BaseScorer): @@ -38,12 +41,4 @@ class EqualityScorer(BaseScorer): } def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: - assert len(scoring_results) > 0, "Empty scoring results provided." - num_correct = sum(result["score"] for result in scoring_results) - avg_score = num_correct / len(scoring_results) - - return { - "accuracy": avg_score, - "num_correct": num_correct, - "num_total": len(scoring_results), - } + return aggregate_accuracy(scoring_results) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py index cb0c3fc6d..651fbf65e 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/inclusion_scorer.py @@ -10,6 +10,9 @@ from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer impor from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.providers.impls.meta_reference.scoring.scorer.common import ( + aggregate_accuracy, +) class InclusionScorer(BaseScorer): @@ -38,12 +41,4 @@ class InclusionScorer(BaseScorer): } def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: - assert len(scoring_results) > 0, "Empty scoring results provided." - num_correct = sum(result["score"] for result in scoring_results) - avg_score = num_correct / len(scoring_results) - - return { - "accuracy": avg_score, - "num_correct": num_correct, - "num_total": len(scoring_results), - } + return aggregate_accuracy(scoring_results)