[/scoring] add ability to define aggregation functions for scoring functions & refactors (#597)

# What does this PR do?

- Add ability to define aggregation functions for scoring functions via
`ScoringFnParams`
- Supported by `basic` / `regex_parser` / `llm_as_judge` scoring
functions


## Test Plan

```
pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py
```
<img width="855" alt="image"
src="https://github.com/user-attachments/assets/12db8e6e-2ad4-462e-b9b9-70ba6c050a6c">


```
pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py
```
<img width="858" alt="image"
src="https://github.com/user-attachments/assets/bf806676-6f5e-456d-be9f-f81a26d1df19">



**Example Response** (`basic`)
<img width="863" alt="image"
src="https://github.com/user-attachments/assets/0e57a49c-8386-45cc-8fa9-3e61aaa9a3be">

**Example Response** (`llm-as-judge`)
<img width="854" alt="image"
src="https://github.com/user-attachments/assets/38065bc2-b724-47ed-9535-79b6099c4362">


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Xi Yan 2024-12-11 10:03:42 -08:00 committed by GitHub
parent e128f2547a
commit a4bcfb8bba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 323 additions and 55 deletions

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import statistics
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
@ -26,3 +27,38 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]
)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores)))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
scores = [r["score"] for r in scoring_results if r["score"] is not None]
median = statistics.median(scores) if scores else None
return {"median": median}
# TODO: decide whether we want to make aggregation functions as a registerable resource
AGGREGATION_FUNCTIONS = {
AggregationFunctionType.accuracy: aggregate_accuracy,
AggregationFunctionType.average: aggregate_average,
AggregationFunctionType.categorical_count: aggregate_categorical_count,
AggregationFunctionType.median: aggregate_median,
}
def aggregate_metrics(
scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
) -> Dict[str, Any]:
agg_results = {}
for metric in metrics:
if metric not in AGGREGATION_FUNCTIONS:
raise ValueError(f"Aggregation function {metric} not found")
agg_fn = AGGREGATION_FUNCTIONS[metric]
agg_results[metric] = agg_fn(scoring_results)
return agg_results

View file

@ -8,11 +8,12 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
class BaseScoringFn(ABC):
"""
Base interface class for all meta-reference scoring_fns.
Base interface class for all native scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
@ -44,11 +45,27 @@ class BaseScoringFn(ABC):
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
raise NotImplementedError()
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
if params is None:
params = scoring_params
else:
params.aggregation_functions = scoring_params.aggregation_functions
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)
async def score(
self,