From 199f92ddad0f1e53f71e3e439abb6af01f0c09a5 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 19 Dec 2024 12:34:49 -0800
Subject: [PATCH] refactor base scoring fn vs. registerable scoring fn

---
 .../basic/scoring_fn/regex_parser_scoring_fn.py       |  4 ++--
 .../scoring/basic/scoring_fn/subset_of_scoring_fn.py  |  4 ++--
 .../providers/inline/scoring/braintrust/braintrust.py |  9 ++++++---
 .../scoring_fn/fn_defs/answer_correctness.py          | 10 ++++++++--
 .../braintrust/scoring_fn/fn_defs/factuality.py       | 10 ++++++++--
 5 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
index 552f34d46..38014ca6f 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
@@ -9,14 +9,14 @@ from typing import Any, Dict, Optional
 
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 
 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
 )
 
 
-class RegexParserScoringFn(BaseScoringFn):
+class RegexParserScoringFn(RegisteredBaseScoringFn):
     """
     A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
     """
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
index 29ae12e44..71defc433 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
@@ -8,12 +8,12 @@ from typing import Any, Dict, Optional
 
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
 
 from .fn_defs.subset_of import subset_of
 
 
-class SubsetOfScoringFn(BaseScoringFn):
+class SubsetOfScoringFn(RegisteredBaseScoringFn):
     """
     A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
""" diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index ae9555403..fcb48fd33 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -20,7 +20,7 @@ from autoevals.ragas import AnswerCorrectness from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average +from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from .config import BraintrustScoringConfig from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def @@ -147,8 +147,11 @@ class BraintrustScoringImpl( await self.score_row(input_row, scoring_fn_id) for input_row in input_rows ] - aggregation_functions = [AggregationFunctionType.average] - agg_results = aggregate_average(score_results) + aggregation_functions = self.supported_fn_defs_registry[ + scoring_fn_id + ].params.aggregation_functions + + agg_results = aggregate_metrics(score_results, aggregation_functions) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py index dc5df8e78..becadf97b 100644 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py @@ -5,14 +5,20 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) answer_correctness_fn_def = ScoringFn( identifier="braintrust::answer-correctness", description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", - params=None, provider_id="braintrust", provider_resource_id="answer-correctness", return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), ) diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py index b733f10c8..88cd7b3a7 100644 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py @@ -5,14 +5,20 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) factuality_fn_def = ScoringFn( identifier="braintrust::factuality", description="Test whether an output is factual, compared to an original (`expected`) value. 
One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", - params=None, provider_id="braintrust", provider_resource_id="factuality", return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), )
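Note for reviewers: below is a minimal sketch of the pattern this patch moves to. Each fn_def now declares its own aggregation in params, and the provider reads it back out of its registry and hands it to aggregate_metrics instead of hard-coding aggregate_average. The standalone registry dict and the shape of the score rows are illustrative assumptions, not code from this patch.

# Illustrative sketch only -- mirrors the calls introduced above, not a drop-in test.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics

# A fn_def that carries its own aggregation config, as answer_correctness and
# factuality now do in this patch.
example_fn_def = ScoringFn(
    identifier="braintrust::factuality",
    description="Example definition carrying aggregation params.",
    provider_id="braintrust",
    provider_resource_id="factuality",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.average]
    ),
)

# Stand-in for BraintrustScoringImpl.supported_fn_defs_registry (assumption: a
# plain dict keyed by the scoring function identifier).
supported_fn_defs_registry = {example_fn_def.identifier: example_fn_def}

# Assumed row shape: each score row carries a numeric "score" value.
score_results = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]

# Same lookup-then-aggregate flow as the new code in braintrust.py.
aggregation_functions = supported_fn_defs_registry[
    "braintrust::factuality"
].params.aggregation_functions
agg_results = aggregate_metrics(score_results, aggregation_functions)
print(agg_results)  # aggregated result dict; exact shape depends on aggregate_metrics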