From 7aab3c63f4d1f05a13f03bb3ddaace8ff0ad76b8 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 10 Dec 2024 16:26:27 -0800 Subject: [PATCH] refactor --- .../scoring/basic/scoring_fn/equality_scoring_fn.py | 4 ++-- .../scoring/basic/scoring_fn/fn_defs/equality.py | 10 ++++++++-- .../fn_defs/regex_parser_multiple_choice_answer.py | 8 ++++++-- .../scoring/basic/scoring_fn/fn_defs/subset_of.py | 9 ++++++++- .../basic/scoring_fn/regex_parser_scoring_fn.py | 8 ++------ .../scoring/basic/scoring_fn/subset_of_scoring_fn.py | 4 ++-- .../providers/inline/scoring/braintrust/braintrust.py | 2 +- llama_stack/providers/utils/scoring/base_scoring_fn.py | 3 +-- 8 files changed, 30 insertions(+), 18 deletions(-) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index f1a063910..0a3fe877f 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams +from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn @@ -54,7 +54,7 @@ class EqualityScoringFn(BaseScoringFn): if scoring_params is not None: params = scoring_params - aggregation_functions = [AggregationFunctionType.accuracy] + aggregation_functions = [] if ( params and hasattr(params, "aggregation_functions") diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py index 8403119f6..c20171829 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py @@ -5,14 +5,20 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) equality = ScoringFn( identifier="basic::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", - params=None, provider_id="basic", provider_resource_id="equality", return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.accuracy] + ), ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py index 9d028a468..b7a649a48 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py @@ -4,9 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + RegexParserScoringFnParams, + ScoringFn, +) MULTILINGUAL_ANSWER_REGEXES = [ r"Answer\s*:", @@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn( MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES ], + aggregation_functions=[AggregationFunctionType.accuracy], ), ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py index ab2a9c60b..98f54afb5 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py @@ -5,7 +5,11 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) subset_of = ScoringFn( @@ -14,4 +18,7 @@ subset_of = ScoringFn( return_type=NumberType(), provider_id="basic", provider_resource_id="subset-of", + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.accuracy] + ), ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index ff0602d76..326d090f8 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -8,11 +8,7 @@ import re from typing import Any, Dict, List, Optional from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import ( - AggregationFunctionType, - ScoringFnParams, - ScoringFnParamsType, -) +from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn @@ -76,7 +72,7 @@ class RegexParserScoringFn(BaseScoringFn): if scoring_params is not None: params = scoring_params - aggregation_functions = [AggregationFunctionType.accuracy] + aggregation_functions = [] if ( params and hasattr(params, "aggregation_functions") diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py index 3ab456a63..e813cbe6d 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional from llama_stack.apis.scoring import ScoringResultRow -from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams +from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn @@ -48,7 +48,7 @@ class SubsetOfScoringFn(BaseScoringFn): if scoring_params is not None: params = scoring_params - aggregation_functions = [AggregationFunctionType.accuracy] + aggregation_functions = [] if ( params and hasattr(params, "aggregation_functions") diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index 8b22a8930..ae9555403 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -147,7 +147,7 @@ class BraintrustScoringImpl( await self.score_row(input_row, scoring_fn_id) for input_row in input_rows ] - + aggregation_functions = [AggregationFunctionType.average] agg_results = aggregate_average(score_results) res[scoring_fn_id] = ScoringResult( score_rows=score_results, diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py index 180abf6cd..56ec4dfae 100644 --- a/llama_stack/providers/utils/scoring/base_scoring_fn.py +++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py @@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn class BaseScoringFn(ABC): """ - Base interface class for all meta-reference scoring_fns. + Base interface class for all native scoring_fns. Each scoring_fn needs to implement the following methods: - score_row(self, row) - aggregate(self, scoring_fn_results) @@ -20,7 +20,6 @@ class BaseScoringFn(ABC): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = {} def __str__(self) -> str: return self.__class__.__name__