aggregation function config

This commit is contained in:
Xi Yan 2024-12-10 15:46:46 -08:00
parent e2054d53e4
commit fbc3888fd7
10 changed files with 189 additions and 26 deletions

View file

@ -7,7 +7,12 @@
import pytest
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
@ -129,7 +134,7 @@ class TestScoring:
assert len(rows.rows) == 3
scoring_functions = {
"llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
"llm-as-judge::base": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
judge_score_regexes=[r"Score: (\d+)"],
@ -154,3 +159,59 @@ class TestScoring:
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_aggregation_functions(self, scoring_stack):
(
scoring_impl,
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl)
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
scoring_functions = {}
aggr_fns = [
AggregationFunctionType.accuracy,
AggregationFunctionType.median,
AggregationFunctionType.categorical_count,
AggregationFunctionType.average,
]
for x in scoring_fns_list:
if x.provider_id == "llm-as-judge":
scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
aggregation_functions=[AggregationFunctionType.categorical_count],
)
elif x.provider_id == "basic":
if "regex_parser" in x.identifier:
scoring_functions[x.identifier] = RegexParserScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = BasicScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = None
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
from rich.pretty import pprint
pprint(response)