forked from phoenix-oss/llama-stack-mirror
[/scoring] add ability to define aggregation functions for scoring functions & refactors (#597)
# What does this PR do? - Add ability to define aggregation functions for scoring functions via `ScoringFnParams` - Supported by `basic` / `regex_parser` / `llm_as_judge` scoring functions ## Test Plan ``` pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py ``` <img width="855" alt="image" src="https://github.com/user-attachments/assets/12db8e6e-2ad4-462e-b9b9-70ba6c050a6c"> ``` pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py ``` <img width="858" alt="image" src="https://github.com/user-attachments/assets/bf806676-6f5e-456d-be9f-f81a26d1df19"> **Example Response** (`basic`) <img width="863" alt="image" src="https://github.com/user-attachments/assets/0e57a49c-8386-45cc-8fa9-3e61aaa9a3be"> **Example Response** (`llm-as-judge`) <img width="854" alt="image" src="https://github.com/user-attachments/assets/38065bc2-b724-47ed-9535-79b6099c4362"> ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
e128f2547a
commit
a4bcfb8bba
16 changed files with 323 additions and 55 deletions
|
@ -113,7 +113,9 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
agg_results = await scoring_fn.aggregate(score_results)
|
||||
agg_results = await scoring_fn.aggregate(
|
||||
score_results, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
|
||||
from .fn_defs.equality import equality
|
||||
|
||||
|
@ -42,8 +42,3 @@ class EqualityScoringFn(BaseScoringFn):
|
|||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
||||
|
|
|
@ -5,14 +5,20 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
equality = ScoringFn(
|
||||
identifier="basic::equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
params=None,
|
||||
provider_id="basic",
|
||||
provider_resource_id="equality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
)
|
||||
|
|
|
@ -4,9 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"Answer\s*:",
|
||||
|
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
|
|||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||
for x in MULTILINGUAL_ANSWER_REGEXES
|
||||
],
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
|
@ -5,7 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
subset_of = ScoringFn(
|
||||
|
@ -14,4 +18,7 @@ subset_of = ScoringFn(
|
|||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="subset-of",
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
)
|
||||
|
|
|
@ -5,11 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
import re
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
|
||||
|
||||
from .fn_defs.regex_parser_multiple_choice_answer import (
|
||||
regex_parser_multiple_choice_answer,
|
||||
|
@ -60,8 +60,3 @@ class RegexParserScoringFn(BaseScoringFn):
|
|||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
|
||||
|
||||
from .fn_defs.subset_of import subset_of
|
||||
|
||||
|
@ -36,8 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn):
|
|||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue