This commit is contained in:
Xi Yan 2024-12-10 16:26:27 -08:00
parent 1077c521b1
commit 7aab3c63f4
8 changed files with 30 additions and 18 deletions

View file

@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
@ -54,7 +54,7 @@ class EqualityScoringFn(BaseScoringFn):
if scoring_params is not None:
params = scoring_params
aggregation_functions = [AggregationFunctionType.accuracy]
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")

View file

@ -5,14 +5,20 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
params=None,
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
)

View file

@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:",
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -5,7 +5,11 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
subset_of = ScoringFn(
@ -14,4 +18,7 @@ subset_of = ScoringFn(
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
)

View file

@ -8,11 +8,7 @@ import re
from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
ScoringFnParams,
ScoringFnParamsType,
)
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
@ -76,7 +72,7 @@ class RegexParserScoringFn(BaseScoringFn):
if scoring_params is not None:
params = scoring_params
aggregation_functions = [AggregationFunctionType.accuracy]
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")

View file

@ -7,7 +7,7 @@
from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
@ -48,7 +48,7 @@ class SubsetOfScoringFn(BaseScoringFn):
if scoring_params is not None:
params = scoring_params
aggregation_functions = [AggregationFunctionType.accuracy]
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")

View file

@ -147,7 +147,7 @@ class BraintrustScoringImpl(
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = [AggregationFunctionType.average]
agg_results = aggregate_average(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,

View file

@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
class BaseScoringFn(ABC):
"""
Base interface class for all meta-reference scoring_fns.
Base interface class for all native scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
@ -20,7 +20,6 @@ class BaseScoringFn(ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {}
def __str__(self) -> str:
return self.__class__.__name__