This commit is contained in:
Xi Yan 2024-10-24 14:00:41 -07:00
parent 3db1b3fbcd
commit ba0186f2c8
3 changed files with 27 additions and 18 deletions

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    """Aggregate per-row scores into summary accuracy statistics.

    Args:
        scoring_results: Rows that each provide a numeric value under the
            ``"score"`` key.

    Returns:
        A dict with three entries: ``"accuracy"`` (the sum of all scores
        divided by the number of rows), ``"num_correct"`` (the sum of all
        scores) and ``"num_total"`` (the number of rows).

    Raises:
        ValueError: If ``scoring_results`` is empty, since no average can be
            computed over zero rows.
    """
    num_total = len(scoring_results)
    if num_total == 0:
        # Fail with a descriptive error rather than letting the division
        # below surface as a bare ZeroDivisionError. A real exception is used
        # (not ``assert``) so the check survives ``python -O``.
        raise ValueError("Empty scoring results provided.")

    num_correct = sum(result["score"] for result in scoring_results)
    return {
        "accuracy": num_correct / num_total,
        "num_correct": num_correct,
        "num_total": num_total,
    }

View file

@ -10,6 +10,9 @@ from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer impor
from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scorer.common import (
aggregate_accuracy,
)
class EqualityScorer(BaseScorer): class EqualityScorer(BaseScorer):
@ -38,12 +41,4 @@ class EqualityScorer(BaseScorer):
} }
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
assert len(scoring_results) > 0, "Empty scoring results provided." return aggregate_accuracy(scoring_results)
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}

View file

@ -10,6 +10,9 @@ from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer impor
from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.impls.meta_reference.scoring.scorer.common import (
aggregate_accuracy,
)
class InclusionScorer(BaseScorer): class InclusionScorer(BaseScorer):
@ -38,12 +41,4 @@ class InclusionScorer(BaseScorer):
} }
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
assert len(scoring_results) > 0, "Empty scoring results provided." return aggregate_accuracy(scoring_results)
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}