fix scoring

Xi Yan 2025-03-05 16:20:11 -08:00
parent 546a417b09
commit 5d43b9157e


@@ -162,8 +162,17 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge
         assert len(response.results[x].score_rows) == 5
 
 
-@pytest.mark.skip(reason="Skipping because this seems to be really slow")
-def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
+@pytest.mark.parametrize(
+    "provider_id",
+    [
+        "basic",
+        "llm-as-judge",
+        "braintrust",
+    ],
+)
+def test_scoring_score_with_aggregation_functions(
+    llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id
+):
     register_dataset(llama_stack_client, for_rag=True)
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
@@ -171,7 +180,10 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
     )
     assert len(rows.rows) == 3
 
-    scoring_fns_list = llama_stack_client.scoring_functions.list()
+    scoring_fns_list = [x for x in llama_stack_client.scoring_functions.list() if x.provider_id == provider_id]
+    if len(scoring_fns_list) == 0:
+        pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
+
     scoring_functions = {}
     aggr_fns = [
         "accuracy",
@@ -179,30 +191,31 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
         "categorical_count",
         "average",
     ]
-    for x in scoring_fns_list:
-        if x.provider_id == "llm-as-judge":
-            aggr_fns = ["categorical_count"]
-            scoring_functions[x.identifier] = dict(
-                type="llm_as_judge",
-                judge_model=judge_model_id,
-                prompt_template=sample_judge_prompt_template,
-                judge_score_regexes=[r"Score: (\d+)"],
-                aggregation_functions=aggr_fns,
-            )
-        elif x.provider_id == "basic" or x.provider_id == "braintrust":
-            if "regex_parser" in x.identifier:
-                scoring_functions[x.identifier] = dict(
-                    type="regex_parser",
-                    parsing_regexes=[r"Score: (\d+)"],
-                    aggregation_functions=aggr_fns,
-                )
-            else:
-                scoring_functions[x.identifier] = dict(
-                    type="basic",
-                    aggregation_functions=aggr_fns,
-                )
-        else:
-            scoring_functions[x.identifier] = None
+
+    scoring_fn = scoring_fns_list[0]
+    if scoring_fn.provider_id == "llm-as-judge":
+        aggr_fns = ["categorical_count"]
+        scoring_functions[scoring_fn.identifier] = dict(
+            type="llm_as_judge",
+            judge_model=judge_model_id,
+            prompt_template=sample_judge_prompt_template,
+            judge_score_regexes=[r"Score: (\d+)"],
+            aggregation_functions=aggr_fns,
+        )
+    elif scoring_fn.provider_id == "basic" or scoring_fn.provider_id == "braintrust":
+        if "regex_parser" in scoring_fn.identifier:
+            scoring_functions[scoring_fn.identifier] = dict(
+                type="regex_parser",
+                parsing_regexes=[r"Score: (\d+)"],
+                aggregation_functions=aggr_fns,
+            )
+        else:
+            scoring_functions[scoring_fn.identifier] = dict(
+                type="basic",
+                aggregation_functions=aggr_fns,
+            )
+    else:
+        scoring_functions[scoring_fn.identifier] = None
 
     response = llama_stack_client.scoring.score(
         input_rows=rows.rows,
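
For context, a minimal, self-contained sketch of the parametrize-and-skip pattern this patch introduces. The FakeClient class, the client fixture, and the registered function identifiers below are hypothetical stand-ins for the real llama_stack_client fixture; only the pytest mechanics (@pytest.mark.parametrize plus pytest.skip) mirror the patched test.

# Sketch only: stubbed client, not the real llama_stack_client fixture.
from types import SimpleNamespace

import pytest


class FakeClient:
    """Pretend client exposing scoring_functions.list() like the test assumes."""

    def __init__(self):
        fns = [
            SimpleNamespace(identifier="basic::equality", provider_id="basic"),
            SimpleNamespace(identifier="llm-as-judge::base", provider_id="llm-as-judge"),
        ]
        self.scoring_functions = SimpleNamespace(list=lambda: fns)


@pytest.fixture
def client():
    return FakeClient()


@pytest.mark.parametrize("provider_id", ["basic", "llm-as-judge", "braintrust"])
def test_picks_one_scoring_fn_per_provider(client, provider_id):
    # Same filtering logic as the patched test: keep only this provider's
    # functions, and skip the case when the provider registered none.
    fns = [x for x in client.scoring_functions.list() if x.provider_id == provider_id]
    if len(fns) == 0:
        pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
    scoring_fn = fns[0]
    assert scoring_fn.provider_id == provider_id

Under these assumptions pytest collects three cases, one per provider_id; the braintrust case is reported as skipped rather than failed, which is how the patched test degrades when a provider has no scoring functions registered.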