aggregation function config

This commit is contained in:
Xi Yan 2024-12-10 15:46:46 -08:00
parent e2054d53e4
commit fbc3888fd7
10 changed files with 189 additions and 26 deletions

View file

@ -113,7 +113,7 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -4,12 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from typing import Any, Dict, List, Optional
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.equality import equality
@ -44,6 +45,15 @@ class EqualityScoringFn(BaseScoringFn):
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)
aggregation_functions = [AggregationFunctionType.accuracy]
if (
scoring_params
and hasattr(scoring_params, "aggregation_functions")
and scoring_params.aggregation_functions
):
aggregation_functions.extend(scoring_params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)

View file

@ -5,11 +5,16 @@
# the root directory of this source tree.
import re
from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
ScoringFnParams,
ScoringFnParamsType,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
@ -62,6 +67,15 @@ class RegexParserScoringFn(BaseScoringFn):
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)
aggregation_functions = [AggregationFunctionType.accuracy]
if (
scoring_params
and hasattr(scoring_params, "aggregation_functions")
and scoring_params.aggregation_functions
):
aggregation_functions.extend(scoring_params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.subset_of import subset_of
@ -38,6 +39,15 @@ class SubsetOfScoringFn(BaseScoringFn):
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)
aggregation_functions = [AggregationFunctionType.accuracy]
if (
scoring_params
and hasattr(scoring_params, "aggregation_functions")
and scoring_params.aggregation_functions
):
aggregation_functions.extend(scoring_params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)

View file

@ -120,7 +120,7 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_params)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -87,7 +87,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
print(f"scoring_params: {scoring_params}")
# TODO: this needs to be config based aggregation, and only useful w/ Jobs API
return {}