diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 14e311cfc..9a9a29439 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -4926,6 +4926,15 @@
"config"
]
},
+ "AggregationFunctionType": {
+ "type": "string",
+ "enum": [
+ "average",
+ "median",
+ "categorical_count",
+ "accuracy"
+ ]
+ },
"AppEvalTaskConfig": {
"type": "object",
"properties": {
@@ -4953,6 +4962,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/BasicScoringFnParams"
}
]
}
@@ -4968,6 +4980,26 @@
"scoring_params"
]
},
+ "BasicScoringFnParams": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "basic",
+ "default": "basic"
+ },
+ "aggregation_functions": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/AggregationFunctionType"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type"
+ ]
+ },
"BenchmarkEvalTaskConfig": {
"type": "object",
"properties": {
@@ -5015,6 +5047,12 @@
"items": {
"type": "string"
}
+ },
+ "aggregation_functions": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/AggregationFunctionType"
+ }
}
},
"additionalProperties": false,
@@ -5061,6 +5099,12 @@
"items": {
"type": "string"
}
+ },
+ "aggregation_functions": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/AggregationFunctionType"
+ }
}
},
"additionalProperties": false,
@@ -6014,6 +6058,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/BasicScoringFnParams"
}
]
}
@@ -7771,6 +7818,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/BasicScoringFnParams"
}
]
}
@@ -7998,6 +8048,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/BasicScoringFnParams"
}
]
},
@@ -8046,6 +8099,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
+ },
+ {
+ "$ref": "#/components/schemas/BasicScoringFnParams"
}
]
},
@@ -8491,6 +8547,10 @@
{
"name": "Agents"
},
+ {
+ "name": "AggregationFunctionType",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/AggregationFunctionType\" />"
+ },
    {
      "name": "AppEvalTaskConfig",
      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/AppEvalTaskConfig\" />"
@@ -8503,6 +8563,10 @@
      "name": "Attachment",
      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Attachment\" />"
},
+ {
+ "name": "BasicScoringFnParams",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BasicScoringFnParams\" />"
+ },
    {
      "name": "BatchChatCompletionRequest",
      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BatchChatCompletionRequest\" />"
@@ -9146,9 +9210,11 @@
"AgentTurnResponseStreamChunk",
"AgentTurnResponseTurnCompletePayload",
"AgentTurnResponseTurnStartPayload",
+ "AggregationFunctionType",
"AppEvalTaskConfig",
"AppendRowsRequest",
"Attachment",
+ "BasicScoringFnParams",
"BatchChatCompletionRequest",
"BatchChatCompletionResponse",
"BatchCompletionRequest",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 86fcae23d..a1cd08387 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -216,6 +216,13 @@ components:
- event_type
- turn_id
type: object
+ AggregationFunctionType:
+ enum:
+ - average
+ - median
+ - categorical_count
+ - accuracy
+ type: string
AppEvalTaskConfig:
additionalProperties: false
properties:
@@ -230,6 +237,7 @@ components:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
+ - $ref: '#/components/schemas/BasicScoringFnParams'
type: object
type:
const: app
@@ -280,6 +288,20 @@ components:
- content
- mime_type
type: object
+ BasicScoringFnParams:
+ additionalProperties: false
+ properties:
+ aggregation_functions:
+ items:
+ $ref: '#/components/schemas/AggregationFunctionType'
+ type: array
+ type:
+ const: basic
+ default: basic
+ type: string
+ required:
+ - type
+ type: object
BatchChatCompletionRequest:
additionalProperties: false
properties:
@@ -1280,6 +1302,10 @@ components:
LLMAsJudgeScoringFnParams:
additionalProperties: false
properties:
+ aggregation_functions:
+ items:
+ $ref: '#/components/schemas/AggregationFunctionType'
+ type: array
judge_model:
type: string
judge_score_regexes:
@@ -1984,6 +2010,10 @@ components:
RegexParserScoringFnParams:
additionalProperties: false
properties:
+ aggregation_functions:
+ items:
+ $ref: '#/components/schemas/AggregationFunctionType'
+ type: array
parsing_regexes:
items:
type: string
@@ -2195,6 +2225,7 @@ components:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
+ - $ref: '#/components/schemas/BasicScoringFnParams'
provider_id:
type: string
provider_scoring_fn_id:
@@ -2515,6 +2546,7 @@ components:
- oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
+ - $ref: '#/components/schemas/BasicScoringFnParams'
- type: 'null'
type: object
required:
@@ -2555,6 +2587,7 @@ components:
- oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
+ - $ref: '#/components/schemas/BasicScoringFnParams'
- type: 'null'
type: object
required:
@@ -2592,6 +2625,7 @@ components:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
+ - $ref: '#/components/schemas/BasicScoringFnParams'
provider_id:
type: string
provider_resource_id:
@@ -5161,6 +5195,9 @@ tags:
/>
name: AgentTurnResponseTurnStartPayload
- name: Agents
+- description: <SchemaDefinition schemaRef="#/components/schemas/AggregationFunctionType"
+    />
+  name: AggregationFunctionType
- description: <SchemaDefinition schemaRef="#/components/schemas/AppEvalTaskConfig"
    />
  name: AppEvalTaskConfig
@@ -5169,6 +5206,9 @@ tags:
name: AppendRowsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/Attachment" />
  name: Attachment
+- description: <SchemaDefinition schemaRef="#/components/schemas/BasicScoringFnParams"
+    />
+  name: BasicScoringFnParams
- description: <SchemaDefinition schemaRef="#/components/schemas/BatchChatCompletionRequest"
    />
  name: BatchChatCompletionRequest
@@ -5636,9 +5676,11 @@ x-tagGroups:
- AgentTurnResponseStreamChunk
- AgentTurnResponseTurnCompletePayload
- AgentTurnResponseTurnStartPayload
+ - AggregationFunctionType
- AppEvalTaskConfig
- AppendRowsRequest
- Attachment
+ - BasicScoringFnParams
- BatchChatCompletionRequest
- BatchChatCompletionResponse
- BatchCompletionRequest
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 4dce5a46d..fc57cfbbf 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -31,6 +31,15 @@ from llama_stack.apis.resource import Resource, ResourceType
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
+ basic = "basic"
+
+
+@json_schema_type
+class AggregationFunctionType(Enum):
+ average = "average"
+ median = "median"
+ categorical_count = "categorical_count"
+ accuracy = "accuracy"
@json_schema_type
@@ -44,6 +53,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
description="Regexes to extract the answer from generated response",
default_factory=list,
)
+ aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+ description="Aggregation functions to apply to the scores of each row",
+ default_factory=list,
+ )
@json_schema_type
@@ -55,12 +68,26 @@ class RegexParserScoringFnParams(BaseModel):
description="Regex to extract the answer from generated response",
default_factory=list,
)
+ aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+ description="Aggregation functions to apply to the scores of each row",
+ default_factory=list,
+ )
+
+
+@json_schema_type
+class BasicScoringFnParams(BaseModel):
+ type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
+ aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+ description="Aggregation functions to apply to the scores of each row",
+ default_factory=list,
+ )
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
+ BasicScoringFnParams,
],
Field(discriminator="type"),
]
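The new BasicScoringFnParams variant and the shared aggregation_functions field make aggregation configurable per scoring function. A minimal sketch of constructing these params (assumes llama_stack is importable and Pydantic v2, so model_dump() is available; the regex and the chosen aggregations are illustrative, not from this patch):

```python
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    RegexParserScoringFnParams,
)

# Basic scoring function params: report accuracy and median over the per-row scores.
basic_params = BasicScoringFnParams(
    aggregation_functions=[
        AggregationFunctionType.accuracy,
        AggregationFunctionType.median,
    ]
)

# Regex-parser params: extract the answer first, then report a categorical count of scores.
regex_params = RegexParserScoringFnParams(
    parsing_regexes=[r"Answer:\s*(\w+)"],
    aggregation_functions=[AggregationFunctionType.categorical_count],
)

# The Literal "type" field is the discriminator of the ScoringFnParams union, so the
# serialized payloads line up with the updated OpenAPI schema ("type": "basic", etc.).
print(basic_params.model_dump())
print(regex_params.model_dump())
```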
diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index ac8f8630f..0c0503ff5 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -113,7 +113,9 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
- agg_results = await scoring_fn.aggregate(score_results)
+ agg_results = await scoring_fn.aggregate(
+ score_results, scoring_fn_id, scoring_fn_params
+ )
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
index 7eba4a21b..9991c5502 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
@@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import * # noqa: F401, F403
-from llama_stack.apis.scoring import * # noqa: F401, F403
-from llama_stack.apis.common.type_system import * # noqa: F403
+from typing import Any, Dict, Optional
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
+from llama_stack.apis.scoring import ScoringResultRow
+
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.equality import equality
@@ -42,8 +42,3 @@ class EqualityScoringFn(BaseScoringFn):
return {
"score": score,
}
-
- async def aggregate(
- self, scoring_results: List[ScoringResultRow]
- ) -> Dict[str, Any]:
- return aggregate_accuracy(scoring_results)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py
index 8403119f6..c20171829 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py
@@ -5,14 +5,20 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFn
+from llama_stack.apis.scoring_functions import (
+ AggregationFunctionType,
+ BasicScoringFnParams,
+ ScoringFn,
+)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
- params=None,
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
+ params=BasicScoringFnParams(
+ aggregation_functions=[AggregationFunctionType.accuracy]
+ ),
)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
index 9d028a468..b7a649a48 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
@@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_stack.apis.scoring_functions import * # noqa: F401, F403
-from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import (
+ AggregationFunctionType,
+ RegexParserScoringFnParams,
+ ScoringFn,
+)
MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:",
@@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
+ aggregation_functions=[AggregationFunctionType.accuracy],
),
)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py
index ab2a9c60b..98f54afb5 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py
@@ -5,7 +5,11 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFn
+from llama_stack.apis.scoring_functions import (
+ AggregationFunctionType,
+ BasicScoringFnParams,
+ ScoringFn,
+)
subset_of = ScoringFn(
@@ -14,4 +18,7 @@ subset_of = ScoringFn(
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
+ params=BasicScoringFnParams(
+ aggregation_functions=[AggregationFunctionType.accuracy]
+ ),
)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
index fd036ced1..552f34d46 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
@@ -5,11 +5,11 @@
# the root directory of this source tree.
import re
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import * # noqa: F401, F403
-from llama_stack.apis.scoring import * # noqa: F401, F403
-from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
@@ -60,8 +60,3 @@ class RegexParserScoringFn(BaseScoringFn):
return {
"score": score,
}
-
- async def aggregate(
- self, scoring_results: List[ScoringResultRow]
- ) -> Dict[str, Any]:
- return aggregate_accuracy(scoring_results)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
index 1ff3c9b1c..29ae12e44 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
@@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+from typing import Any, Dict, Optional
+
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import * # noqa: F401, F403
-from llama_stack.apis.scoring import * # noqa: F401, F403
-from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.subset_of import subset_of
@@ -36,8 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn):
return {
"score": score,
}
-
- async def aggregate(
- self, scoring_results: List[ScoringResultRow]
- ) -> Dict[str, Any]:
- return aggregate_accuracy(scoring_results)
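With the per-class aggregate() overrides removed, a provider scoring function now only implements score_row and inherits the params-driven aggregate() from BaseScoringFn. A rough sketch of a new scoring function under this pattern; the "basic::exact_prefix" fn_def, the __init__/score_row signatures, and the input_row keys are illustrative assumptions modeled on the existing basic scoring functions rather than code from this patch:

```python
from typing import Any, Dict, Optional

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
    ScoringFnParams,
)
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn

# Hypothetical fn_def, modeled on basic::equality above; not part of this patch.
exact_prefix = ScoringFn(
    identifier="basic::exact_prefix",
    description="Returns 1.0 if the generated answer starts with the expected answer.",
    provider_id="basic",
    provider_resource_id="exact-prefix",
    return_type=NumberType(),
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.accuracy]
    ),
)


class ExactPrefixScoringFn(BaseScoringFn):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {exact_prefix.identifier: exact_prefix}

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = "basic::exact_prefix",
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        expected = input_row["expected_answer"]
        generated = input_row["generated_answer"]
        return {"score": 1.0 if generated.startswith(expected) else 0.0}

    # No aggregate() override needed: BaseScoringFn.aggregate() reads aggregation_functions
    # from the registered fn_def params (or per-request scoring_params) and applies them.
```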
diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
index 8b22a8930..ae9555403 100644
--- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py
+++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -147,7 +147,7 @@ class BraintrustScoringImpl(
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
-
+ aggregation_functions = [AggregationFunctionType.average]
agg_results = aggregate_average(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
index 33462631c..09780e6fb 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
@@ -120,7 +120,9 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
- agg_results = await scoring_fn.aggregate(score_results)
+ agg_results = await scoring_fn.aggregate(
+ score_results, scoring_fn_id, scoring_fn_params
+ )
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
index 3f4df3304..00ea53c8f 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
@@ -3,13 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import re
+
+from typing import Any, Dict, Optional
+
from llama_stack.apis.inference.inference import Inference
+from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring_functions import ScoringFnParams
+
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
-from llama_stack.apis.scoring_functions import * # noqa: F401, F403
-from llama_stack.apis.scoring import * # noqa: F401, F403
-from llama_stack.apis.common.type_system import * # noqa: F403
-import re
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
@@ -85,9 +88,3 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
"score": judge_rating,
"judge_feedback": content,
}
-
- async def aggregate(
- self, scoring_results: List[ScoringResultRow]
- ) -> Dict[str, Any]:
- # TODO: this needs to be config based aggregation, and only useful w/ Jobs API
- return {}
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 08a05681f..846d30cbb 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -7,7 +7,12 @@
import pytest
-from llama_stack.apis.scoring_functions import * # noqa: F403
+from llama_stack.apis.scoring_functions import (
+ AggregationFunctionType,
+ BasicScoringFnParams,
+ LLMAsJudgeScoringFnParams,
+ RegexParserScoringFnParams,
+)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
@@ -18,6 +23,11 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_datase
# -v -s --tb=short --disable-warnings
+@pytest.fixture
+def sample_judge_prompt_template():
+    return "Output a number response in the following format: Score: <score>, where <score> is the number between 0 and 9."
+
+
class TestScoring:
@pytest.mark.asyncio
async def test_scoring_functions_list(self, scoring_stack):
@@ -92,7 +102,9 @@ class TestScoring:
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
- async def test_scoring_score_with_params(self, scoring_stack):
+ async def test_scoring_score_with_params_llm_as_judge(
+ self, scoring_stack, sample_judge_prompt_template
+ ):
(
scoring_impl,
scoring_functions_impl,
@@ -129,10 +141,11 @@ class TestScoring:
assert len(rows.rows) == 3
scoring_functions = {
- "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+ "llm-as-judge::base": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
-                prompt_template="Output a number response in the following format: Score: <score>, where <score> is the number between 0 and 9.",
+ prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
+ aggregation_functions=[AggregationFunctionType.categorical_count],
)
}
@@ -154,3 +167,67 @@ class TestScoring:
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
+
+ @pytest.mark.asyncio
+ async def test_scoring_score_with_aggregation_functions(
+ self, scoring_stack, sample_judge_prompt_template
+ ):
+ (
+ scoring_impl,
+ scoring_functions_impl,
+ datasetio_impl,
+ datasets_impl,
+ models_impl,
+ ) = (
+ scoring_stack[Api.scoring],
+ scoring_stack[Api.scoring_functions],
+ scoring_stack[Api.datasetio],
+ scoring_stack[Api.datasets],
+ scoring_stack[Api.models],
+ )
+ await register_dataset(datasets_impl)
+ rows = await datasetio_impl.get_rows_paginated(
+ dataset_id="test_dataset",
+ rows_in_page=3,
+ )
+ assert len(rows.rows) == 3
+
+ scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
+ scoring_functions = {}
+ aggr_fns = [
+ AggregationFunctionType.accuracy,
+ AggregationFunctionType.median,
+ AggregationFunctionType.categorical_count,
+ AggregationFunctionType.average,
+ ]
+ for x in scoring_fns_list:
+ if x.provider_id == "llm-as-judge":
+ aggr_fns = [AggregationFunctionType.categorical_count]
+ scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
+ judge_model="Llama3.1-405B-Instruct",
+ prompt_template=sample_judge_prompt_template,
+ judge_score_regexes=[r"Score: (\d+)"],
+ aggregation_functions=aggr_fns,
+ )
+ elif x.provider_id == "basic":
+ if "regex_parser" in x.identifier:
+ scoring_functions[x.identifier] = RegexParserScoringFnParams(
+ aggregation_functions=aggr_fns,
+ )
+ else:
+ scoring_functions[x.identifier] = BasicScoringFnParams(
+ aggregation_functions=aggr_fns,
+ )
+ else:
+ scoring_functions[x.identifier] = None
+
+ response = await scoring_impl.score(
+ input_rows=rows.rows,
+ scoring_functions=scoring_functions,
+ )
+
+ assert len(response.results) == len(scoring_functions)
+ for x in scoring_functions:
+ assert x in response.results
+ assert len(response.results[x].score_rows) == len(rows.rows)
+ assert len(response.results[x].aggregated_results) == len(aggr_fns)
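From the caller's side, the net effect is that aggregation can be overridden per request. A small usage sketch against a scoring_impl wired up like the one in these tests (the helper and identifiers are illustrative; per-request params take precedence over the fn_def defaults, as implemented in BaseScoringFn.aggregate further below):

```python
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
)


async def score_equality_with_median(scoring_impl, input_rows):
    # basic::equality defaults to accuracy aggregation via its fn_def params; passing
    # explicit params here swaps in a median aggregation for this request only.
    response = await scoring_impl.score(
        input_rows=input_rows,
        scoring_functions={
            "basic::equality": BasicScoringFnParams(
                aggregation_functions=[AggregationFunctionType.median]
            ),
        },
    )
    return response.results["basic::equality"].aggregated_results
```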
diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py
index 1ca0c7fb3..7b9d58944 100644
--- a/llama_stack/providers/utils/scoring/aggregation_utils.py
+++ b/llama_stack/providers/utils/scoring/aggregation_utils.py
@@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import statistics
from typing import Any, Dict, List
-from llama_stack.apis.scoring import ScoringResultRow
+from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
@@ -26,3 +27,38 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]
)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
+
+
+def aggregate_categorical_count(
+ scoring_results: List[ScoringResultRow],
+) -> Dict[str, Any]:
+ scores = [str(r["score"]) for r in scoring_results]
+ unique_scores = sorted(list(set(scores)))
+ return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
+
+
+def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+ scores = [r["score"] for r in scoring_results if r["score"] is not None]
+ median = statistics.median(scores) if scores else None
+ return {"median": median}
+
+
+# TODO: decide whether we want to make aggregation functions a registerable resource
+AGGREGATION_FUNCTIONS = {
+ AggregationFunctionType.accuracy: aggregate_accuracy,
+ AggregationFunctionType.average: aggregate_average,
+ AggregationFunctionType.categorical_count: aggregate_categorical_count,
+ AggregationFunctionType.median: aggregate_median,
+}
+
+
+def aggregate_metrics(
+ scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
+) -> Dict[str, Any]:
+ agg_results = {}
+ for metric in metrics:
+ if metric not in AGGREGATION_FUNCTIONS:
+ raise ValueError(f"Aggregation function {metric} not found")
+ agg_fn = AGGREGATION_FUNCTIONS[metric]
+ agg_results[metric] = agg_fn(scoring_results)
+ return agg_results
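The aggregation helpers are plain functions over score rows, so they can be exercised in isolation. A self-contained example of aggregate_metrics on hand-written rows (the expected shapes in the comments follow aggregate_median and aggregate_categorical_count above):

```python
from llama_stack.apis.scoring_functions import AggregationFunctionType
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics

score_rows = [
    {"score": 1.0},
    {"score": 0.0},
    {"score": 1.0},
]

results = aggregate_metrics(
    score_rows,
    [
        AggregationFunctionType.median,
        AggregationFunctionType.categorical_count,
    ],
)

# Keys are the requested AggregationFunctionType members; values are the per-metric dicts,
# e.g. {"median": 1.0} and {"categorical_count": {"0.0": 1, "1.0": 2}} for these rows.
print(results[AggregationFunctionType.median])
print(results[AggregationFunctionType.categorical_count])
```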
diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py
index 8cd101c50..2db77fd2b 100644
--- a/llama_stack/providers/utils/scoring/base_scoring_fn.py
+++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py
@@ -8,11 +8,12 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
class BaseScoringFn(ABC):
"""
- Base interface class for all meta-reference scoring_fns.
+ Base interface class for all native scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
@@ -44,11 +45,27 @@ class BaseScoringFn(ABC):
) -> ScoringResultRow:
raise NotImplementedError()
- @abstractmethod
async def aggregate(
- self, scoring_results: List[ScoringResultRow]
+ self,
+ scoring_results: List[ScoringResultRow],
+ scoring_fn_identifier: Optional[str] = None,
+ scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
- raise NotImplementedError()
+ params = self.supported_fn_defs_registry[scoring_fn_identifier].params
+ if scoring_params is not None:
+ if params is None:
+ params = scoring_params
+ else:
+ params.aggregation_functions = scoring_params.aggregation_functions
+
+ aggregation_functions = []
+ if (
+ params
+ and hasattr(params, "aggregation_functions")
+ and params.aggregation_functions
+ ):
+ aggregation_functions.extend(params.aggregation_functions)
+ return aggregate_metrics(scoring_results, aggregation_functions)
async def score(
self,