From fbc3888fd7a4c116b03c0e12df75448917c44f50 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 10 Dec 2024 15:46:46 -0800
Subject: [PATCH] aggregation function config

---
 .../scoring_functions/scoring_functions.py    | 27 ++++++++
 .../providers/inline/scoring/basic/scoring.py |  2 +-
 .../basic/scoring_fn/equality_scoring_fn.py   | 24 +++++--
 .../scoring_fn/regex_parser_scoring_fn.py     | 26 ++++++--
 .../basic/scoring_fn/subset_of_scoring_fn.py  | 22 +++++--
 .../inline/scoring/llm_as_judge/scoring.py    |  2 +-
 .../scoring_fn/llm_as_judge_scoring_fn.py     |  5 +-
 .../providers/tests/scoring/test_scoring.py   | 65 ++++++++++++++++++-
 .../utils/scoring/aggregation_utils.py        | 38 ++++++++++-
 .../utils/scoring/base_scoring_fn.py          |  4 +-
 10 files changed, 189 insertions(+), 26 deletions(-)

diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 4dce5a46d..fc57cfbbf 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -31,6 +31,15 @@ from llama_stack.apis.resource import Resource, ResourceType
 class ScoringFnParamsType(Enum):
     llm_as_judge = "llm_as_judge"
     regex_parser = "regex_parser"
+    basic = "basic"
+
+
+@json_schema_type
+class AggregationFunctionType(Enum):
+    average = "average"
+    median = "median"
+    categorical_count = "categorical_count"
+    accuracy = "accuracy"
 
 
 @json_schema_type
@@ -44,6 +53,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
         description="Regexes to extract the answer from generated response",
         default_factory=list,
     )
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
 
 
 @json_schema_type
@@ -55,12 +68,26 @@ class RegexParserScoringFnParams(BaseModel):
         description="Regex to extract the answer from generated response",
         default_factory=list,
     )
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
+
+
+@json_schema_type
+class BasicScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
 
 
 ScoringFnParams = Annotated[
     Union[
         LLMAsJudgeScoringFnParams,
         RegexParserScoringFnParams,
+        BasicScoringFnParams,
     ],
     Field(discriminator="type"),
 ]
diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index ac8f8630f..273cf4b3c 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -113,7 +113,7 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             score_results = await scoring_fn.score(
                 input_rows, scoring_fn_id, scoring_fn_params
             )
-            agg_results = await scoring_fn.aggregate(score_results)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
index 7eba4a21b..e36177ecb 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
@@ -4,12 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
-from llama_stack.apis.scoring import *  # noqa: F401, F403
-from llama_stack.apis.common.type_system import *  # noqa: F403
+from typing import Any, Dict, List, Optional
 
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
+from llama_stack.apis.scoring import ScoringResultRow
+
+from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
+from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
 
 from .fn_defs.equality import equality
 
@@ -44,6 +45,15 @@ class EqualityScoringFn(BaseScoringFn):
         }
 
     async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
+        self,
+        scoring_results: List[ScoringResultRow],
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> Dict[str, Any]:
-        return aggregate_accuracy(scoring_results)
+        aggregation_functions = [AggregationFunctionType.accuracy]
+        if (
+            scoring_params
+            and hasattr(scoring_params, "aggregation_functions")
+            and scoring_params.aggregation_functions
+        ):
+            aggregation_functions.extend(scoring_params.aggregation_functions)
+        return aggregate_metrics(scoring_results, aggregation_functions)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
index fd036ced1..67ca62a96 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
@@ -5,11 +5,16 @@
 # the root directory of this source tree.
 
 import re
+from typing import Any, Dict, List, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    ScoringFnParams,
+    ScoringFnParamsType,
+)
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
 from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
-from llama_stack.apis.scoring import *  # noqa: F401, F403
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
 
 from .fn_defs.regex_parser_multiple_choice_answer import (
     regex_parser_multiple_choice_answer,
@@ -62,6 +67,15 @@ class RegexParserScoringFn(BaseScoringFn):
         }
 
     async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
+        self,
+        scoring_results: List[ScoringResultRow],
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> Dict[str, Any]:
-        return aggregate_accuracy(scoring_results)
+        aggregation_functions = [AggregationFunctionType.accuracy]
+        if (
+            scoring_params
+            and hasattr(scoring_params, "aggregation_functions")
+            and scoring_params.aggregation_functions
+        ):
+            aggregation_functions.extend(scoring_params.aggregation_functions)
+        return aggregate_metrics(scoring_results, aggregation_functions)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
index 1ff3c9b1c..3da978f16 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
@@ -4,11 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Any, Dict, List, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFnParams
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
 from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
-from llama_stack.apis.scoring import *  # noqa: F401, F403
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
 
 from .fn_defs.subset_of import subset_of
 
@@ -38,6 +39,15 @@ class SubsetOfScoringFn(BaseScoringFn):
         }
 
     async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
+        self,
+        scoring_results: List[ScoringResultRow],
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> Dict[str, Any]:
-        return aggregate_accuracy(scoring_results)
+        aggregation_functions = [AggregationFunctionType.accuracy]
+        if (
+            scoring_params
+            and hasattr(scoring_params, "aggregation_functions")
+            and scoring_params.aggregation_functions
+        ):
+            aggregation_functions.extend(scoring_params.aggregation_functions)
+        return aggregate_metrics(scoring_results, aggregation_functions)
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
index 33462631c..3f63241a5 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
@@ -120,7 +120,7 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             score_results = await scoring_fn.score(
                 input_rows, scoring_fn_id, scoring_fn_params
             )
-            agg_results = await scoring_fn.aggregate(score_results)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
index 3f4df3304..e51d512e4 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
@@ -87,7 +87,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         }
 
     async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
+        self,
+        scoring_results: List[ScoringResultRow],
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> Dict[str, Any]:
+        print(f"scoring_params: {scoring_params}")
         # TODO: this needs to be config based aggregation, and only useful w/ Jobs API
         return {}
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 08a05681f..8fc41878b 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -7,7 +7,12 @@
 import pytest
 
-from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    LLMAsJudgeScoringFnParams,
+    RegexParserScoringFnParams,
+)
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
 
@@ -129,7 +134,7 @@ class TestScoring:
         assert len(rows.rows) == 3
 
         scoring_functions = {
-            "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+            "llm-as-judge::base": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
                 judge_score_regexes=[r"Score: (\d+)"],
@@ -154,3 +159,59 @@ class TestScoring:
         for x in scoring_functions:
             assert x in response.results
             assert len(response.results[x].score_rows) == 5
+
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_aggregation_functions(self, scoring_stack):
+        (
+            scoring_impl,
+            scoring_functions_impl,
+            datasetio_impl,
+            datasets_impl,
+            models_impl,
+        ) = (
+            scoring_stack[Api.scoring],
+            scoring_stack[Api.scoring_functions],
+            scoring_stack[Api.datasetio],
+            scoring_stack[Api.datasets],
+            scoring_stack[Api.models],
+        )
+        await register_dataset(datasets_impl)
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
+        provider_id = scoring_fns_list[0].provider_id
+        scoring_functions = {}
+        aggr_fns = [
+            AggregationFunctionType.accuracy,
+            AggregationFunctionType.median,
+            AggregationFunctionType.categorical_count,
+            AggregationFunctionType.average,
+        ]
+        for x in scoring_fns_list:
+            if x.provider_id == "llm-as-judge":
+                scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
+                    aggregation_functions=[AggregationFunctionType.categorical_count],
+                )
+            elif x.provider_id == "basic":
+                if "regex_parser" in x.identifier:
+                    scoring_functions[x.identifier] = RegexParserScoringFnParams(
+                        aggregation_functions=aggr_fns,
+                    )
+                else:
+                    scoring_functions[x.identifier] = BasicScoringFnParams(
+                        aggregation_functions=aggr_fns,
+                    )
+            else:
+                scoring_functions[x.identifier] = None
+
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        from rich.pretty import pprint
+
+        pprint(response)
diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py
index 1ca0c7fb3..7b9d58944 100644
--- a/llama_stack/providers/utils/scoring/aggregation_utils.py
+++ b/llama_stack/providers/utils/scoring/aggregation_utils.py
@@ -3,9 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import statistics
 from typing import Any, Dict, List
 
-from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
 
 
 def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
@@ -26,3 +27,38 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]
         )
         / len([_ for _ in scoring_results if _["score"] is not None]),
     }
+
+
+def aggregate_categorical_count(
+    scoring_results: List[ScoringResultRow],
+) -> Dict[str, Any]:
+    scores = [str(r["score"]) for r in scoring_results]
+    unique_scores = sorted(list(set(scores)))
+    return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
+
+
+def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    scores = [r["score"] for r in scoring_results if r["score"] is not None]
+    median = statistics.median(scores) if scores else None
+    return {"median": median}
+
+
+# TODO: decide whether we want to make aggregation functions as a registerable resource
+AGGREGATION_FUNCTIONS = {
+    AggregationFunctionType.accuracy: aggregate_accuracy,
+    AggregationFunctionType.average: aggregate_average,
+    AggregationFunctionType.categorical_count: aggregate_categorical_count,
+    AggregationFunctionType.median: aggregate_median,
+}
+
+
+def aggregate_metrics(
+    scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
+) -> Dict[str, Any]:
+    agg_results = {}
+    for metric in metrics:
+        if metric not in AGGREGATION_FUNCTIONS:
+            raise ValueError(f"Aggregation function {metric} not found")
+        agg_fn = AGGREGATION_FUNCTIONS[metric]
+        agg_results[metric] = agg_fn(scoring_results)
+    return agg_results
diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py
index 8cd101c50..ef3edc451 100644
--- a/llama_stack/providers/utils/scoring/base_scoring_fn.py
+++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py
@@ -46,7 +46,9 @@ class BaseScoringFn(ABC):
 
     @abstractmethod
    async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
+        self,
+        scoring_results: List[ScoringResultRow],
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> Dict[str, Any]:
         raise NotImplementedError()
 
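Usage sketch (editor's note, not part of the patch): the snippet below shows how a caller might pass the new aggregation_functions config through the scoring API, mirroring test_scoring_score_with_aggregation_functions above. It is a minimal sketch under assumptions: the scoring_impl handle, the input rows, and the "basic::subset_of" / "basic::regex_parser_multiple_choice_answer" identifiers are illustrative and not defined in this patch.

from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    RegexParserScoringFnParams,
)


async def score_with_aggregations(scoring_impl, input_rows):
    # Per scoring function, request row-level scores plus the aggregations
    # introduced by this patch; the basic scoring fns always compute accuracy
    # and extend it with whatever is passed here.
    scoring_functions = {
        "basic::subset_of": BasicScoringFnParams(
            aggregation_functions=[
                AggregationFunctionType.accuracy,
                AggregationFunctionType.median,
            ],
        ),
        "basic::regex_parser_multiple_choice_answer": RegexParserScoringFnParams(
            aggregation_functions=[AggregationFunctionType.categorical_count],
        ),
    }
    response = await scoring_impl.score(
        input_rows=input_rows,
        scoring_functions=scoring_functions,
    )
    # Each ScoringResult.aggregated_results is produced by aggregate_metrics(),
    # keyed by the requested AggregationFunctionType values.
    return response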