mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
fix scoring
This commit is contained in:
parent
546a417b09
commit
5d43b9157e
1 changed files with 37 additions and 24 deletions
|
@ -162,8 +162,17 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge
|
||||||
assert len(response.results[x].score_rows) == 5
|
assert len(response.results[x].score_rows) == 5
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Skipping because this seems to be really slow")
|
@pytest.mark.parametrize(
|
||||||
def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
"provider_id",
|
||||||
|
[
|
||||||
|
"basic",
|
||||||
|
"llm-as-judge",
|
||||||
|
"braintrust",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_scoring_score_with_aggregation_functions(
|
||||||
|
llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id
|
||||||
|
):
|
||||||
register_dataset(llama_stack_client, for_rag=True)
|
register_dataset(llama_stack_client, for_rag=True)
|
||||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||||
dataset_id="test_dataset",
|
dataset_id="test_dataset",
|
||||||
|
@ -171,7 +180,10 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
|
||||||
)
|
)
|
||||||
assert len(rows.rows) == 3
|
assert len(rows.rows) == 3
|
||||||
|
|
||||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
scoring_fns_list = [x for x in llama_stack_client.scoring_functions.list() if x.provider_id == provider_id]
|
||||||
|
if len(scoring_fns_list) == 0:
|
||||||
|
pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
|
||||||
|
|
||||||
scoring_functions = {}
|
scoring_functions = {}
|
||||||
aggr_fns = [
|
aggr_fns = [
|
||||||
"accuracy",
|
"accuracy",
|
||||||
|
@ -179,30 +191,31 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
|
||||||
"categorical_count",
|
"categorical_count",
|
||||||
"average",
|
"average",
|
||||||
]
|
]
|
||||||
for x in scoring_fns_list:
|
|
||||||
if x.provider_id == "llm-as-judge":
|
scoring_fn = scoring_fns_list[0]
|
||||||
aggr_fns = ["categorical_count"]
|
if scoring_fn.provider_id == "llm-as-judge":
|
||||||
scoring_functions[x.identifier] = dict(
|
aggr_fns = ["categorical_count"]
|
||||||
type="llm_as_judge",
|
scoring_functions[scoring_fn.identifier] = dict(
|
||||||
judge_model=judge_model_id,
|
type="llm_as_judge",
|
||||||
prompt_template=sample_judge_prompt_template,
|
judge_model=judge_model_id,
|
||||||
judge_score_regexes=[r"Score: (\d+)"],
|
prompt_template=sample_judge_prompt_template,
|
||||||
|
judge_score_regexes=[r"Score: (\d+)"],
|
||||||
|
aggregation_functions=aggr_fns,
|
||||||
|
)
|
||||||
|
elif scoring_fn.provider_id == "basic" or scoring_fn.provider_id == "braintrust":
|
||||||
|
if "regex_parser" in scoring_fn.identifier:
|
||||||
|
scoring_functions[scoring_fn.identifier] = dict(
|
||||||
|
type="regex_parser",
|
||||||
|
parsing_regexes=[r"Score: (\d+)"],
|
||||||
aggregation_functions=aggr_fns,
|
aggregation_functions=aggr_fns,
|
||||||
)
|
)
|
||||||
elif x.provider_id == "basic" or x.provider_id == "braintrust":
|
|
||||||
if "regex_parser" in x.identifier:
|
|
||||||
scoring_functions[x.identifier] = dict(
|
|
||||||
type="regex_parser",
|
|
||||||
parsing_regexes=[r"Score: (\d+)"],
|
|
||||||
aggregation_functions=aggr_fns,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
scoring_functions[x.identifier] = dict(
|
|
||||||
type="basic",
|
|
||||||
aggregation_functions=aggr_fns,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
scoring_functions[x.identifier] = None
|
scoring_functions[scoring_fn.identifier] = dict(
|
||||||
|
type="basic",
|
||||||
|
aggregation_functions=aggr_fns,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scoring_functions[scoring_fn.identifier] = None
|
||||||
|
|
||||||
response = llama_stack_client.scoring.score(
|
response = llama_stack_client.scoring.score(
|
||||||
input_rows=rows.rows,
|
input_rows=rows.rows,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue