diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 14e311cfc..9a9a29439 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -4926,6 +4926,15 @@ "config" ] }, + "AggregationFunctionType": { + "type": "string", + "enum": [ + "average", + "median", + "categorical_count", + "accuracy" + ] + }, "AppEvalTaskConfig": { "type": "object", "properties": { @@ -4953,6 +4962,9 @@ }, { "$ref": "#/components/schemas/RegexParserScoringFnParams" + }, + { + "$ref": "#/components/schemas/BasicScoringFnParams" } ] } @@ -4968,6 +4980,26 @@ "scoring_params" ] }, + "BasicScoringFnParams": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "basic", + "default": "basic" + }, + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + } + } + }, + "additionalProperties": false, + "required": [ + "type" + ] + }, "BenchmarkEvalTaskConfig": { "type": "object", "properties": { @@ -5015,6 +5047,12 @@ "items": { "type": "string" } + }, + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + } } }, "additionalProperties": false, @@ -5061,6 +5099,12 @@ "items": { "type": "string" } + }, + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + } } }, "additionalProperties": false, @@ -6014,6 +6058,9 @@ }, { "$ref": "#/components/schemas/RegexParserScoringFnParams" + }, + { + "$ref": "#/components/schemas/BasicScoringFnParams" } ] } @@ -7771,6 +7818,9 @@ }, { "$ref": "#/components/schemas/RegexParserScoringFnParams" + }, + { + "$ref": "#/components/schemas/BasicScoringFnParams" } ] } @@ -7998,6 +8048,9 @@ }, { "$ref": "#/components/schemas/RegexParserScoringFnParams" + }, + { + "$ref": "#/components/schemas/BasicScoringFnParams" } ] }, @@ -8046,6 +8099,9 @@ }, { "$ref": "#/components/schemas/RegexParserScoringFnParams" + }, + { + "$ref": "#/components/schemas/BasicScoringFnParams" } ] }, @@ -8491,6 +8547,10 @@ { "name": "Agents" }, + { + "name": "AggregationFunctionType", + "description": "" + }, { "name": "AppEvalTaskConfig", "description": "" @@ -8503,6 +8563,10 @@ "name": "Attachment", "description": "" }, + { + "name": "BasicScoringFnParams", + "description": "" + }, { "name": "BatchChatCompletionRequest", "description": "" @@ -9146,9 +9210,11 @@ "AgentTurnResponseStreamChunk", "AgentTurnResponseTurnCompletePayload", "AgentTurnResponseTurnStartPayload", + "AggregationFunctionType", "AppEvalTaskConfig", "AppendRowsRequest", "Attachment", + "BasicScoringFnParams", "BatchChatCompletionRequest", "BatchChatCompletionResponse", "BatchCompletionRequest", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 86fcae23d..a1cd08387 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -216,6 +216,13 @@ components: - event_type - turn_id type: object + AggregationFunctionType: + enum: + - average + - median + - categorical_count + - accuracy + type: string AppEvalTaskConfig: additionalProperties: false properties: @@ -230,6 +237,7 @@ components: oneOf: - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' type: object type: const: app @@ -280,6 +288,20 @@ components: - content - mime_type type: object + BasicScoringFnParams: + 
additionalProperties: false + properties: + aggregation_functions: + items: + $ref: '#/components/schemas/AggregationFunctionType' + type: array + type: + const: basic + default: basic + type: string + required: + - type + type: object BatchChatCompletionRequest: additionalProperties: false properties: @@ -1280,6 +1302,10 @@ components: LLMAsJudgeScoringFnParams: additionalProperties: false properties: + aggregation_functions: + items: + $ref: '#/components/schemas/AggregationFunctionType' + type: array judge_model: type: string judge_score_regexes: @@ -1984,6 +2010,10 @@ components: RegexParserScoringFnParams: additionalProperties: false properties: + aggregation_functions: + items: + $ref: '#/components/schemas/AggregationFunctionType' + type: array parsing_regexes: items: type: string @@ -2195,6 +2225,7 @@ components: oneOf: - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' provider_id: type: string provider_scoring_fn_id: @@ -2515,6 +2546,7 @@ components: - oneOf: - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' - type: 'null' type: object required: @@ -2555,6 +2587,7 @@ components: - oneOf: - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' - type: 'null' type: object required: @@ -2592,6 +2625,7 @@ components: oneOf: - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' provider_id: type: string provider_resource_id: @@ -5161,6 +5195,9 @@ tags: /> name: AgentTurnResponseTurnStartPayload - name: Agents +- description: + name: AggregationFunctionType - description: name: AppEvalTaskConfig @@ -5169,6 +5206,9 @@ tags: name: AppendRowsRequest - description: name: Attachment +- description: + name: BasicScoringFnParams - description: name: BatchChatCompletionRequest @@ -5636,9 +5676,11 @@ x-tagGroups: - AgentTurnResponseStreamChunk - AgentTurnResponseTurnCompletePayload - AgentTurnResponseTurnStartPayload + - AggregationFunctionType - AppEvalTaskConfig - AppendRowsRequest - Attachment + - BasicScoringFnParams - BatchChatCompletionRequest - BatchChatCompletionResponse - BatchCompletionRequest diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 4dce5a46d..fc57cfbbf 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -31,6 +31,15 @@ from llama_stack.apis.resource import Resource, ResourceType class ScoringFnParamsType(Enum): llm_as_judge = "llm_as_judge" regex_parser = "regex_parser" + basic = "basic" + + +@json_schema_type +class AggregationFunctionType(Enum): + average = "average" + median = "median" + categorical_count = "categorical_count" + accuracy = "accuracy" @json_schema_type @@ -44,6 +53,10 @@ class LLMAsJudgeScoringFnParams(BaseModel): description="Regexes to extract the answer from generated response", default_factory=list, ) + aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + description="Aggregation functions to apply to the scores of each row", + default_factory=list, + ) @json_schema_type @@ -55,12 +68,26 @@ class 
RegexParserScoringFnParams(BaseModel): description="Regex to extract the answer from generated response", default_factory=list, ) + aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + description="Aggregation functions to apply to the scores of each row", + default_factory=list, + ) + + +@json_schema_type +class BasicScoringFnParams(BaseModel): + type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value + aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + description="Aggregation functions to apply to the scores of each row", + default_factory=list, + ) ScoringFnParams = Annotated[ Union[ LLMAsJudgeScoringFnParams, RegexParserScoringFnParams, + BasicScoringFnParams, ], Field(discriminator="type"), ] diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index ac8f8630f..0c0503ff5 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -113,7 +113,9 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): score_results = await scoring_fn.score( input_rows, scoring_fn_id, scoring_fn_params ) - agg_results = await scoring_fn.aggregate(score_results) + agg_results = await scoring_fn.aggregate( + score_results, scoring_fn_id, scoring_fn_params + ) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index 7eba4a21b..9991c5502 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 -from llama_stack.apis.common.type_system import * # noqa: F403 +from typing import Any, Dict, Optional -from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy +from llama_stack.apis.scoring import ScoringResultRow + +from llama_stack.apis.scoring_functions import ScoringFnParams +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from .fn_defs.equality import equality @@ -42,8 +42,3 @@ class EqualityScoringFn(BaseScoringFn): return { "score": score, } - - async def aggregate( - self, scoring_results: List[ScoringResultRow] - ) -> Dict[str, Any]: - return aggregate_accuracy(scoring_results) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py index 8403119f6..c20171829 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py @@ -5,14 +5,20 @@ # the root directory of this source tree. 
from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) equality = ScoringFn( identifier="basic::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", - params=None, provider_id="basic", provider_resource_id="equality", return_type=NumberType(), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.accuracy] + ), ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py index 9d028a468..b7a649a48 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py @@ -4,9 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + RegexParserScoringFnParams, + ScoringFn, +) MULTILINGUAL_ANSWER_REGEXES = [ r"Answer\s*:", @@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn( MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x) for x in MULTILINGUAL_ANSWER_REGEXES ], + aggregation_functions=[AggregationFunctionType.accuracy], ), ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py index ab2a9c60b..98f54afb5 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py @@ -5,7 +5,11 @@ # the root directory of this source tree. from llama_stack.apis.common.type_system import NumberType -from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.apis.scoring_functions import ( + AggregationFunctionType, + BasicScoringFnParams, + ScoringFn, +) subset_of = ScoringFn( @@ -14,4 +18,7 @@ subset_of = ScoringFn( return_type=NumberType(), provider_id="basic", provider_resource_id="subset-of", + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.accuracy] + ), ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index fd036ced1..552f34d46 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -5,11 +5,11 @@ # the root directory of this source tree. 
import re +from typing import Any, Dict, Optional + +from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy from .fn_defs.regex_parser_multiple_choice_answer import ( regex_parser_multiple_choice_answer, @@ -60,8 +60,3 @@ class RegexParserScoringFn(BaseScoringFn): return { "score": score, } - - async def aggregate( - self, scoring_results: List[ScoringResultRow] - ) -> Dict[str, Any]: - return aggregate_accuracy(scoring_results) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py index 1ff3c9b1c..29ae12e44 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py @@ -4,11 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any, Dict, Optional + +from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy from .fn_defs.subset_of import subset_of @@ -36,8 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn): return { "score": score, } - - async def aggregate( - self, scoring_results: List[ScoringResultRow] - ) -> Dict[str, Any]: - return aggregate_accuracy(scoring_results) diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index 8b22a8930..ae9555403 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -147,7 +147,7 @@ class BraintrustScoringImpl( await self.score_row(input_row, scoring_fn_id) for input_row in input_rows ] - + aggregation_functions = [AggregationFunctionType.average] agg_results = aggregate_average(score_results) res[scoring_fn_id] = ScoringResult( score_rows=score_results, diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index 33462631c..09780e6fb 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -120,7 +120,9 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): score_results = await scoring_fn.score( input_rows, scoring_fn_id, scoring_fn_params ) - agg_results = await scoring_fn.aggregate(score_results) + agg_results = await scoring_fn.aggregate( + score_results, scoring_fn_id, scoring_fn_params + ) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git 
a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
index 3f4df3304..00ea53c8f 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
@@ -3,13 +3,16 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import re
+
+from typing import Any, Dict, Optional
+
 from llama_stack.apis.inference.inference import Inference
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
+
 from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
-from llama_stack.apis.scoring import *  # noqa: F401, F403
-from llama_stack.apis.common.type_system import *  # noqa: F403
-import re
 
 from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
 
 
@@ -85,9 +88,3 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             "score": judge_rating,
             "judge_feedback": content,
         }
-
-    async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
-    ) -> Dict[str, Any]:
-        # TODO: this needs to be config based aggregation, and only useful w/ Jobs API
-        return {}
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 08a05681f..846d30cbb 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -7,7 +7,12 @@
 
 import pytest
 
-from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    LLMAsJudgeScoringFnParams,
+    RegexParserScoringFnParams,
+)
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
@@ -18,6 +23,11 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_datase
 # -v -s --tb=short --disable-warnings
 
 
+@pytest.fixture
+def sample_judge_prompt_template():
+    return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
+
+
 class TestScoring:
     @pytest.mark.asyncio
     async def test_scoring_functions_list(self, scoring_stack):
@@ -92,7 +102,9 @@ class TestScoring:
         assert len(response.results[x].score_rows) == 5
 
     @pytest.mark.asyncio
-    async def test_scoring_score_with_params(self, scoring_stack):
+    async def test_scoring_score_with_params_llm_as_judge(
+        self, scoring_stack, sample_judge_prompt_template
+    ):
         (
             scoring_impl,
             scoring_functions_impl,
@@ -129,10 +141,11 @@ class TestScoring:
         assert len(rows.rows) == 3
 
         scoring_functions = {
-            "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+            "llm-as-judge::base": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
-                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
+                prompt_template=sample_judge_prompt_template,
                 judge_score_regexes=[r"Score: (\d+)"],
+                aggregation_functions=[AggregationFunctionType.categorical_count],
             )
         }
 
@@ -154,3 +167,67 @@ class TestScoring:
         for x in scoring_functions:
             assert x in response.results
             assert len(response.results[x].score_rows) == 5
+
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_aggregation_functions(
+        self, scoring_stack, sample_judge_prompt_template
+    ):
+        (
+            scoring_impl,
+            scoring_functions_impl,
+            datasetio_impl,
+            datasets_impl,
+            models_impl,
+        ) = (
+            scoring_stack[Api.scoring],
+            scoring_stack[Api.scoring_functions],
+            scoring_stack[Api.datasetio],
+            scoring_stack[Api.datasets],
+            scoring_stack[Api.models],
+        )
+        await register_dataset(datasets_impl)
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
+        scoring_functions = {}
+        aggr_fns = [
+            AggregationFunctionType.accuracy,
+            AggregationFunctionType.median,
+            AggregationFunctionType.categorical_count,
+            AggregationFunctionType.average,
+        ]
+        for x in scoring_fns_list:
+            if x.provider_id == "llm-as-judge":
+                aggr_fns = [AggregationFunctionType.categorical_count]
+                scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
+                    judge_model="Llama3.1-405B-Instruct",
+                    prompt_template=sample_judge_prompt_template,
+                    judge_score_regexes=[r"Score: (\d+)"],
+                    aggregation_functions=aggr_fns,
+                )
+            elif x.provider_id == "basic":
+                if "regex_parser" in x.identifier:
+                    scoring_functions[x.identifier] = RegexParserScoringFnParams(
+                        aggregation_functions=aggr_fns,
+                    )
+                else:
+                    scoring_functions[x.identifier] = BasicScoringFnParams(
+                        aggregation_functions=aggr_fns,
+                    )
+            else:
+                scoring_functions[x.identifier] = None
+
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+            assert len(response.results[x].aggregated_results) == len(aggr_fns)
diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py
index 1ca0c7fb3..7b9d58944 100644
--- a/llama_stack/providers/utils/scoring/aggregation_utils.py
+++ b/llama_stack/providers/utils/scoring/aggregation_utils.py
@@ -3,9 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import statistics from typing import Any, Dict, List -from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: @@ -26,3 +27,38 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any] ) / len([_ for _ in scoring_results if _["score"] is not None]), } + + +def aggregate_categorical_count( + scoring_results: List[ScoringResultRow], +) -> Dict[str, Any]: + scores = [str(r["score"]) for r in scoring_results] + unique_scores = sorted(list(set(scores))) + return {"categorical_count": {s: scores.count(s) for s in unique_scores}} + + +def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + scores = [r["score"] for r in scoring_results if r["score"] is not None] + median = statistics.median(scores) if scores else None + return {"median": median} + + +# TODO: decide whether we want to make aggregation functions as a registerable resource +AGGREGATION_FUNCTIONS = { + AggregationFunctionType.accuracy: aggregate_accuracy, + AggregationFunctionType.average: aggregate_average, + AggregationFunctionType.categorical_count: aggregate_categorical_count, + AggregationFunctionType.median: aggregate_median, +} + + +def aggregate_metrics( + scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType] +) -> Dict[str, Any]: + agg_results = {} + for metric in metrics: + if metric not in AGGREGATION_FUNCTIONS: + raise ValueError(f"Aggregation function {metric} not found") + agg_fn = AGGREGATION_FUNCTIONS[metric] + agg_results[metric] = agg_fn(scoring_results) + return agg_results diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py index 8cd101c50..2db77fd2b 100644 --- a/llama_stack/providers/utils/scoring/base_scoring_fn.py +++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py @@ -8,11 +8,12 @@ from typing import Any, Dict, List, Optional from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics class BaseScoringFn(ABC): """ - Base interface class for all meta-reference scoring_fns. + Base interface class for all native scoring_fns. Each scoring_fn needs to implement the following methods: - score_row(self, row) - aggregate(self, scoring_fn_results) @@ -44,11 +45,27 @@ class BaseScoringFn(ABC): ) -> ScoringResultRow: raise NotImplementedError() - @abstractmethod async def aggregate( - self, scoring_results: List[ScoringResultRow] + self, + scoring_results: List[ScoringResultRow], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, ) -> Dict[str, Any]: - raise NotImplementedError() + params = self.supported_fn_defs_registry[scoring_fn_identifier].params + if scoring_params is not None: + if params is None: + params = scoring_params + else: + params.aggregation_functions = scoring_params.aggregation_functions + + aggregation_functions = [] + if ( + params + and hasattr(params, "aggregation_functions") + and params.aggregation_functions + ): + aggregation_functions.extend(params.aggregation_functions) + return aggregate_metrics(scoring_results, aggregation_functions) async def score( self,
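
A quick orientation sketch of the aggregation plumbing this diff introduces, not part of the patch itself. The sample rows and the particular aggregation choices are hypothetical; aggregate_metrics and AggregationFunctionType are the symbols added above, and the median output shape matches aggregate_median as defined in this diff:

    from llama_stack.apis.scoring_functions import AggregationFunctionType
    from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics

    # Hypothetical per-row results, shaped as a scoring_fn's score() would produce them.
    score_rows = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]

    # Callers now choose aggregations per request via ScoringFnParams instead of each
    # scoring_fn hard-coding one; an unknown aggregation type raises ValueError.
    agg = aggregate_metrics(
        score_rows,
        [AggregationFunctionType.accuracy, AggregationFunctionType.median],
    )
    # Results are keyed by enum member, e.g.
    # agg[AggregationFunctionType.median] == {"median": 1.0}

This is why BaseScoringFn.aggregate can become a concrete method: it only has to resolve which aggregation_functions apply (request params overriding the registered fn_def's params) and hand them to aggregate_metrics.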