aggregation function config

This commit is contained in:
Xi Yan 2024-12-10 15:46:46 -08:00
parent e2054d53e4
commit fbc3888fd7
10 changed files with 189 additions and 26 deletions

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import statistics
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
@ -26,3 +27,38 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]
)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores)))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
scores = [r["score"] for r in scoring_results if r["score"] is not None]
median = statistics.median(scores) if scores else None
return {"median": median}
# TODO: decide whether we want to make aggregation functions as a registerable resource
AGGREGATION_FUNCTIONS = {
AggregationFunctionType.accuracy: aggregate_accuracy,
AggregationFunctionType.average: aggregate_average,
AggregationFunctionType.categorical_count: aggregate_categorical_count,
AggregationFunctionType.median: aggregate_median,
}
def aggregate_metrics(
scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
) -> Dict[str, Any]:
agg_results = {}
for metric in metrics:
if metric not in AGGREGATION_FUNCTIONS:
raise ValueError(f"Aggregation function {metric} not found")
agg_fn = AGGREGATION_FUNCTIONS[metric]
agg_results[metric] = agg_fn(scoring_results)
return agg_results

View file

@ -46,7 +46,9 @@ class BaseScoringFn(ABC):
@abstractmethod
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
raise NotImplementedError()