This commit is contained in:
Xi Yan 2024-12-10 16:26:27 -08:00
parent 1077c521b1
commit 7aab3c63f4
8 changed files with 30 additions and 18 deletions

View file

@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
@ -54,7 +54,7 @@ class EqualityScoringFn(BaseScoringFn):
if scoring_params is not None: if scoring_params is not None:
params = scoring_params params = scoring_params
aggregation_functions = [AggregationFunctionType.accuracy] aggregation_functions = []
if ( if (
params params
and hasattr(params, "aggregation_functions") and hasattr(params, "aggregation_functions")

View file

@ -5,14 +5,20 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
equality = ScoringFn( equality = ScoringFn(
identifier="basic::equality", identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
params=None,
provider_id="basic", provider_id="basic",
provider_resource_id="equality", provider_resource_id="equality",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
) )

View file

@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MULTILINGUAL_ANSWER_REGEXES = [ MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:", r"Answer\s*:",
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES for x in MULTILINGUAL_ANSWER_REGEXES
], ],
aggregation_functions=[AggregationFunctionType.accuracy],
), ),
) )

View file

@ -5,7 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
subset_of = ScoringFn( subset_of = ScoringFn(
@ -14,4 +18,7 @@ subset_of = ScoringFn(
return_type=NumberType(), return_type=NumberType(),
provider_id="basic", provider_id="basic",
provider_resource_id="subset-of", provider_resource_id="subset-of",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
) )

View file

@ -8,11 +8,7 @@ import re
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ( from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
AggregationFunctionType,
ScoringFnParams,
ScoringFnParamsType,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
@ -76,7 +72,7 @@ class RegexParserScoringFn(BaseScoringFn):
if scoring_params is not None: if scoring_params is not None:
params = scoring_params params = scoring_params
aggregation_functions = [AggregationFunctionType.accuracy] aggregation_functions = []
if ( if (
params params
and hasattr(params, "aggregation_functions") and hasattr(params, "aggregation_functions")

View file

@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
@ -48,7 +48,7 @@ class SubsetOfScoringFn(BaseScoringFn):
if scoring_params is not None: if scoring_params is not None:
params = scoring_params params = scoring_params
aggregation_functions = [AggregationFunctionType.accuracy] aggregation_functions = []
if ( if (
params params
and hasattr(params, "aggregation_functions") and hasattr(params, "aggregation_functions")

View file

@ -147,7 +147,7 @@ class BraintrustScoringImpl(
await self.score_row(input_row, scoring_fn_id) await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows for input_row in input_rows
] ]
aggregation_functions = [AggregationFunctionType.average]
agg_results = aggregate_average(score_results) agg_results = aggregate_average(score_results)
res[scoring_fn_id] = ScoringResult( res[scoring_fn_id] = ScoringResult(
score_rows=score_results, score_rows=score_results,

View file

@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
class BaseScoringFn(ABC): class BaseScoringFn(ABC):
""" """
Base interface class for all meta-reference scoring_fns. Base interface class for all native scoring_fns.
Each scoring_fn needs to implement the following methods: Each scoring_fn needs to implement the following methods:
- score_row(self, row) - score_row(self, row)
- aggregate(self, scoring_fn_results) - aggregate(self, scoring_fn_results)
@ -20,7 +20,6 @@ class BaseScoringFn(ABC):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {}
def __str__(self) -> str: def __str__(self) -> str:
return self.__class__.__name__ return self.__class__.__name__