forked from phoenix-oss/llama-stack-mirror
refactor(test): move tools, evals, datasetio, scoring and post training tests (#1401)
All of the tests from `llama_stack/providers/tests/` are now moved to `tests/integration`. I converted the `tools`, `scoring` and `datasetio` tests to use API. However, `eval` and `post_training` proved to be a bit challenging to leaving those. I think `post_training` should be relatively straightforward also. As part of this, I noticed that `wolfram_alpha` tool wasn't added to some of our commonly used distros so I added it. I am going to remove a lot of code duplication from distros next so while this looks like a one-off right now, it will go away and be there uniformly for all distros.
This commit is contained in:
parent
dd0db8038b
commit
abfbaf3c1b
51 changed files with 471 additions and 1245 deletions
160
tests/integration/scoring/test_scoring.py
Normal file
160
tests/integration/scoring/test_scoring.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from ..datasetio.test_datasetio import register_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_judge_prompt_template():
|
||||
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
|
||||
|
||||
|
||||
def test_scoring_functions_list(llama_stack_client):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
response = llama_stack_client.scoring_functions.list()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) > 0
|
||||
|
||||
|
||||
def test_scoring_score(llama_stack_client):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
||||
scoring_functions = {
|
||||
scoring_fns_list[0].identifier: None,
|
||||
}
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
|
||||
# score batch
|
||||
response = llama_stack_client.scoring.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
save_results_dataset=False,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == 5
|
||||
|
||||
|
||||
def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_functions = {
|
||||
"llm-as-judge::base": dict(
|
||||
type="llm_as_judge",
|
||||
judge_model=judge_model_id,
|
||||
prompt_template=sample_judge_prompt_template,
|
||||
judge_score_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=[
|
||||
"categorical_count",
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
|
||||
# score batch
|
||||
response = llama_stack_client.scoring.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
save_results_dataset=False,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == 5
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping because this seems to be really slow")
|
||||
def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
||||
scoring_functions = {}
|
||||
aggr_fns = [
|
||||
"accuracy",
|
||||
"median",
|
||||
"categorical_count",
|
||||
"average",
|
||||
]
|
||||
for x in scoring_fns_list:
|
||||
if x.provider_id == "llm-as-judge":
|
||||
aggr_fns = ["categorical_count"]
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="llm_as_judge",
|
||||
judge_model=judge_model_id,
|
||||
prompt_template=sample_judge_prompt_template,
|
||||
judge_score_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
elif x.provider_id == "basic" or x.provider_id == "braintrust":
|
||||
if "regex_parser" in x.identifier:
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="regex_parser",
|
||||
parsing_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
else:
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="basic",
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
else:
|
||||
scoring_functions[x.identifier] = None
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
assert len(response.results[x].aggregated_results) == len(aggr_fns)
|
Loading…
Add table
Add a link
Reference in a new issue