forked from phoenix-oss/llama-stack-mirror
test: revamp eval related integration tests (#1433)
# What does this PR do?

- Revamp and clean up the datasets/scoring/eval integration tests
- Closes https://github.com/meta-llama/llama-stack/issues/1396

## Test Plan

**dataset**

```
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/integration/datasetio/
```

<img width="842" alt="image" src="https://github.com/user-attachments/assets/88fc2b6a-b496-47bf-bc0c-8fea48ba36ff" />

**scoring**

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/scoring --text-model meta-llama/Llama-3.1-8B-Instruct --judge-model meta-llama/Llama-3.1-8B-Instruct
```

<img width="851" alt="image" src="https://github.com/user-attachments/assets/50f46415-b44c-4c37-a6c3-076f2767adb3" />

**eval**

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/eval --text-model meta-llama/Llama-3.1-8B-Instruct --judge-model meta-llama/Llama-3.1-8B-Instruct
```

<img width="841" alt="image" src="https://github.com/user-attachments/assets/8eb1c65c-3b39-4d66-8ff4-f471ca783e49" />
parent 82e94fe22f
commit bcb13c492f

7 changed files with 184 additions and 222 deletions
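For context, the `llama_stack_client` fixture used throughout these tests is a client pointed at a running stack. A minimal sketch, assuming the stock `LlamaStackClient` constructor and the base URL from the test plan above:

```python
# Minimal sketch of the client these integration tests operate on; the base URL
# mirrors LLAMA_STACK_BASE_URL from the test plan and is not a fixed default.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# The scoring tests below then exercise the scoring_functions and scoring APIs:
print([fn.identifier for fn in client.scoring_functions.list()])
```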
```diff
@@ -15,14 +15,70 @@ def sample_judge_prompt_template():
     return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."


+@pytest.fixture
+def sample_scoring_fn_id():
+    return "llm-as-judge-test-prompt"
+
+
+def register_scoring_function(
+    llama_stack_client,
+    provider_id,
+    scoring_fn_id,
+    judge_model_id,
+    judge_prompt_template,
+):
+    llama_stack_client.scoring_functions.register(
+        scoring_fn_id=scoring_fn_id,
+        provider_id=provider_id,
+        description="LLM as judge scoring function with test prompt",
+        return_type={
+            "type": "string",
+        },
+        params={
+            "type": "llm_as_judge",
+            "judge_model": judge_model_id,
+            "prompt_template": judge_prompt_template,
+        },
+    )
+
+
 def test_scoring_functions_list(llama_stack_client):
     # NOTE: this needs you to ensure that you are starting from a clean state
     # but so far we don't have an unregister API unfortunately, so be careful
     response = llama_stack_client.scoring_functions.list()
     assert isinstance(response, list)
     assert len(response) > 0


+def test_scoring_functions_register(
+    llama_stack_client,
+    sample_scoring_fn_id,
+    judge_model_id,
+    sample_judge_prompt_template,
+):
+    llm_as_judge_provider = [
+        x
+        for x in llama_stack_client.providers.list()
+        if x.api == "scoring" and x.provider_type == "inline::llm-as-judge"
+    ]
+    if len(llm_as_judge_provider) == 0:
+        pytest.skip("No llm-as-judge provider found, cannot test registration")
+
+    llm_as_judge_provider_id = llm_as_judge_provider[0].provider_id
+    register_scoring_function(
+        llama_stack_client,
+        llm_as_judge_provider_id,
+        sample_scoring_fn_id,
+        judge_model_id,
+        sample_judge_prompt_template,
+    )
+
+    list_response = llama_stack_client.scoring_functions.list()
+    assert isinstance(list_response, list)
+    assert len(list_response) > 0
+    assert any(x.identifier == sample_scoring_fn_id for x in list_response)
+
+    # TODO: add unregister API for scoring functions
+
+
 def test_scoring_score(llama_stack_client):
     register_dataset(llama_stack_client, for_rag=True)
     response = llama_stack_client.datasets.list()
```
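Outside pytest, the new `register_scoring_function` helper boils down to a single `scoring_functions.register` call. A usage sketch, assuming the client from above; the provider and judge model IDs are illustrative, not values pinned by this test:

```python
# Hypothetical direct use of the helper added above; provider_id and the judge
# model depend on whatever the running distribution actually exposes.
register_scoring_function(
    client,
    provider_id="llm-as-judge",
    scoring_fn_id="llm-as-judge-test-prompt",
    judge_model_id="meta-llama/Llama-3.1-8B-Instruct",
    judge_prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
)
assert any(fn.identifier == "llm-as-judge-test-prompt" for fn in client.scoring_functions.list())
```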
```diff
@@ -106,8 +162,17 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge
         assert len(response.results[x].score_rows) == 5


-@pytest.mark.skip(reason="Skipping because this seems to be really slow")
-def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
+@pytest.mark.parametrize(
+    "provider_id",
+    [
+        "basic",
+        "llm-as-judge",
+        "braintrust",
+    ],
+)
+def test_scoring_score_with_aggregation_functions(
+    llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id
+):
     register_dataset(llama_stack_client, for_rag=True)
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
```
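The `provider_id` argument now comes from `pytest.mark.parametrize`, while `judge_model_id` is presumably a fixture wired to the `--judge-model` flag used in the test plan. A hypothetical conftest-style sketch of that wiring (the repo's actual conftest may differ):

```python
# Hypothetical conftest.py sketch; the real option/fixture wiring lives in the
# shared integration-test conftest and may not look exactly like this.
import pytest


def pytest_addoption(parser):
    parser.addoption("--judge-model", default=None, help="model ID used for LLM-as-judge scoring")


@pytest.fixture
def judge_model_id(request):
    model_id = request.config.getoption("--judge-model")
    if model_id is None:
        pytest.skip("--judge-model not provided")
    return model_id
```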
```diff
@@ -115,7 +180,10 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
     )
     assert len(rows.rows) == 3

-    scoring_fns_list = llama_stack_client.scoring_functions.list()
+    scoring_fns_list = [x for x in llama_stack_client.scoring_functions.list() if x.provider_id == provider_id]
+    if len(scoring_fns_list) == 0:
+        pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
+
     scoring_functions = {}
     aggr_fns = [
         "accuracy",
```
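The rows consumed here come from the dataset registered by `register_dataset` (a datasetio test helper that is not part of this diff). A rough sketch of reading that dataset back, assuming the paginated-rows call used in these tests:

```python
# Rough sketch of fetching the registered test dataset's rows; rows_in_page is
# an assumed parameter name matching how these tests page through three rows.
rows = client.datasetio.get_rows_paginated(
    dataset_id="test_dataset",
    rows_in_page=3,
)
for row in rows.rows:
    # each row is a dict keyed by the dataset's schema columns
    print(row)
```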
```diff
@@ -123,30 +191,31 @@ def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_jud
         "categorical_count",
         "average",
     ]
-    for x in scoring_fns_list:
-        if x.provider_id == "llm-as-judge":
-            aggr_fns = ["categorical_count"]
-            scoring_functions[x.identifier] = dict(
-                type="llm_as_judge",
-                judge_model=judge_model_id,
-                prompt_template=sample_judge_prompt_template,
-                judge_score_regexes=[r"Score: (\d+)"],
-                aggregation_functions=aggr_fns,
-            )
-        elif x.provider_id == "basic" or x.provider_id == "braintrust":
-            if "regex_parser" in x.identifier:
-                scoring_functions[x.identifier] = dict(
-                    type="regex_parser",
-                    parsing_regexes=[r"Score: (\d+)"],
-                    aggregation_functions=aggr_fns,
-                )
-            else:
-                scoring_functions[x.identifier] = dict(
-                    type="basic",
-                    aggregation_functions=aggr_fns,
-                )
-        else:
-            scoring_functions[x.identifier] = None
+    scoring_fn = scoring_fns_list[0]
+    if scoring_fn.provider_id == "llm-as-judge":
+        aggr_fns = ["categorical_count"]
+        scoring_functions[scoring_fn.identifier] = dict(
+            type="llm_as_judge",
+            judge_model=judge_model_id,
+            prompt_template=sample_judge_prompt_template,
+            judge_score_regexes=[r"Score: (\d+)"],
+            aggregation_functions=aggr_fns,
+        )
+    elif scoring_fn.provider_id == "basic" or scoring_fn.provider_id == "braintrust":
+        if "regex_parser" in scoring_fn.identifier:
+            scoring_functions[scoring_fn.identifier] = dict(
+                type="regex_parser",
+                parsing_regexes=[r"Score: (\d+)"],
+                aggregation_functions=aggr_fns,
+            )
+        else:
+            scoring_functions[scoring_fn.identifier] = dict(
+                type="basic",
+                aggregation_functions=aggr_fns,
+            )
+    else:
+        scoring_functions[scoring_fn.identifier] = None

     response = llama_stack_client.scoring.score(
         input_rows=rows.rows,
```
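Downstream of this hunk, the assembled `scoring_functions` mapping is handed to `scoring.score`. A sketch of that call and of inspecting its result, based on the `results[...].score_rows` assertions earlier in this file; the aggregated-results field access is an assumption:

```python
# Sketch of the scoring call that follows this hunk; only score_rows is
# confirmed by assertions in this file, the aggregated_results access is assumed.
response = client.scoring.score(
    input_rows=rows.rows,
    scoring_functions=scoring_functions,
)
for fn_id, result in response.results.items():
    print(fn_id, len(result.score_rows))
    print(getattr(result, "aggregated_results", None))
```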