From 5d43b9157e70db078dfdcff5494bff90ac43c53e Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 5 Mar 2025 16:20:11 -0800
Subject: [PATCH] fix scoring

---
 tests/integration/scoring/test_scoring.py | 61 ++++++++++++++---------
 1 file changed, 37 insertions(+), 24 deletions(-)

diff --git a/tests/integration/scoring/test_scoring.py b/tests/integration/scoring/test_scoring.py
index 2cb61303b..20dbfe5f6 100644
--- a/tests/integration/scoring/test_scoring.py
+++ b/tests/integration/scoring/test_scoring.py
@@ -162,8 +162,17 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge
     assert len(response.results[x].score_rows) == 5
 
 
-@pytest.mark.skip(reason="Skipping because this seems to be really slow")
-def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
+@pytest.mark.parametrize(
+    "provider_id",
+    [
+        "basic",
+        "llm-as-judge",
+        "braintrust",
+    ],
+)
+def test_scoring_score_with_aggregation_functions(
+    llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id
+):
     register_dataset(llama_stack_client, for_rag=True)
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
@@ -171,7 +180,10 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
     )
     assert len(rows.rows) == 3
 
-    scoring_fns_list = llama_stack_client.scoring_functions.list()
+    scoring_fns_list = [x for x in llama_stack_client.scoring_functions.list() if x.provider_id == provider_id]
+    if len(scoring_fns_list) == 0:
+        pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
+
     scoring_functions = {}
     aggr_fns = [
         "accuracy",
@@ -179,30 +191,31 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
         "categorical_count",
         "average",
     ]
-    for x in scoring_fns_list:
-        if x.provider_id == "llm-as-judge":
-            aggr_fns = ["categorical_count"]
-            scoring_functions[x.identifier] = dict(
-                type="llm_as_judge",
-                judge_model=judge_model_id,
-                prompt_template=sample_judge_prompt_template,
-                judge_score_regexes=[r"Score: (\d+)"],
+
+    scoring_fn = scoring_fns_list[0]
+    if scoring_fn.provider_id == "llm-as-judge":
+        aggr_fns = ["categorical_count"]
+        scoring_functions[scoring_fn.identifier] = dict(
+            type="llm_as_judge",
+            judge_model=judge_model_id,
+            prompt_template=sample_judge_prompt_template,
+            judge_score_regexes=[r"Score: (\d+)"],
+            aggregation_functions=aggr_fns,
+        )
+    elif scoring_fn.provider_id == "basic" or scoring_fn.provider_id == "braintrust":
+        if "regex_parser" in scoring_fn.identifier:
+            scoring_functions[scoring_fn.identifier] = dict(
+                type="regex_parser",
+                parsing_regexes=[r"Score: (\d+)"],
                 aggregation_functions=aggr_fns,
             )
-        elif x.provider_id == "basic" or x.provider_id == "braintrust":
-            if "regex_parser" in x.identifier:
-                scoring_functions[x.identifier] = dict(
-                    type="regex_parser",
-                    parsing_regexes=[r"Score: (\d+)"],
-                    aggregation_functions=aggr_fns,
-                )
-            else:
-                scoring_functions[x.identifier] = dict(
-                    type="basic",
-                    aggregation_functions=aggr_fns,
-                )
         else:
-            scoring_functions[x.identifier] = None
+            scoring_functions[scoring_fn.identifier] = dict(
+                type="basic",
+                aggregation_functions=aggr_fns,
+            )
+    else:
+        scoring_functions[scoring_fn.identifier] = None
 
     response = llama_stack_client.scoring.score(
         input_rows=rows.rows,
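
For reference, the patch replaces a blanket @pytest.mark.skip with per-provider parametrization plus a runtime pytest.skip when the selected provider has no registered scoring functions. Below is a minimal, self-contained sketch of that parametrize-plus-skip pattern; PROVIDERS and FAKE_REGISTRY are illustrative stand-ins, not part of llama-stack or its client.

# Sketch of the parametrize-plus-skip pattern introduced by the patch.
# FAKE_REGISTRY stands in for llama_stack_client.scoring_functions.list();
# the identifiers in it are hypothetical.
import pytest

PROVIDERS = ["basic", "llm-as-judge", "braintrust"]
FAKE_REGISTRY = {
    "basic": ["basic::subset_of"],  # hypothetical identifier
    "braintrust": [],               # nothing registered -> test will skip
}


@pytest.mark.parametrize("provider_id", PROVIDERS)
def test_one_provider(provider_id):
    fns = FAKE_REGISTRY.get(provider_id, [])
    if len(fns) == 0:
        # Mirrors the pytest.skip(...) added in the patch: providers with no
        # scoring functions skip at runtime instead of disabling the whole test.
        pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
    assert len(fns) > 0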