forked from phoenix-oss/llama-stack-mirror
scoring fix
This commit is contained in:
parent
6f5df08ebf
commit
28b8c1c815
1 changed files with 33 additions and 64 deletions
|
@ -5,23 +5,13 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..datasetio.test_datasetio import register_dataset
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def rag_dataset_for_test(llama_stack_client):
|
|
||||||
dataset_id = "test_dataset"
|
|
||||||
register_dataset(llama_stack_client, for_rag=True, dataset_id=dataset_id)
|
|
||||||
yield # This is where the test function will run
|
|
||||||
|
|
||||||
# Teardown - this always runs, even if the test fails
|
|
||||||
try:
|
|
||||||
llama_stack_client.datasets.unregister(dataset_id)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Failed to unregister test_dataset: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_judge_prompt_template():
|
def sample_judge_prompt_template():
|
||||||
|
@ -92,49 +82,34 @@ def test_scoring_functions_register(
|
||||||
# TODO: add unregister api for scoring functions
|
# TODO: add unregister api for scoring functions
|
||||||
|
|
||||||
|
|
||||||
def test_scoring_score(llama_stack_client, rag_dataset_for_test):
|
@pytest.mark.parametrize("scoring_fn_id", ["basic::equality"])
|
||||||
|
def test_scoring_score(llama_stack_client, scoring_fn_id):
|
||||||
# scoring individual rows
|
# scoring individual rows
|
||||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
df = pd.read_csv(Path(__file__).parent.parent / "datasets" / "test_dataset.csv")
|
||||||
dataset_id="test_dataset",
|
rows = df.to_dict(orient="records")
|
||||||
rows_in_page=3,
|
|
||||||
)
|
|
||||||
assert len(rows.rows) == 3
|
|
||||||
|
|
||||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
|
||||||
scoring_functions = {
|
scoring_functions = {
|
||||||
scoring_fns_list[0].identifier: None,
|
scoring_fn_id: None,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = llama_stack_client.scoring.score(
|
response = llama_stack_client.scoring.score(
|
||||||
input_rows=rows.rows,
|
input_rows=rows,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
)
|
)
|
||||||
assert len(response.results) == len(scoring_functions)
|
assert len(response.results) == len(scoring_functions)
|
||||||
for x in scoring_functions:
|
for x in scoring_functions:
|
||||||
assert x in response.results
|
assert x in response.results
|
||||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
assert len(response.results[x].score_rows) == len(rows)
|
||||||
|
|
||||||
# score batch
|
|
||||||
response = llama_stack_client.scoring.score_batch(
|
|
||||||
dataset_id="test_dataset",
|
|
||||||
scoring_functions=scoring_functions,
|
|
||||||
save_results_dataset=False,
|
|
||||||
)
|
|
||||||
assert len(response.results) == len(scoring_functions)
|
|
||||||
for x in scoring_functions:
|
|
||||||
assert x in response.results
|
|
||||||
assert len(response.results[x].score_rows) == 5
|
|
||||||
|
|
||||||
|
|
||||||
def test_scoring_score_with_params_llm_as_judge(
|
def test_scoring_score_with_params_llm_as_judge(
|
||||||
llama_stack_client, sample_judge_prompt_template, judge_model_id, rag_dataset_for_test
|
llama_stack_client,
|
||||||
|
sample_judge_prompt_template,
|
||||||
|
judge_model_id,
|
||||||
):
|
):
|
||||||
# scoring individual rows
|
# scoring individual rows
|
||||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
df = pd.read_csv(Path(__file__).parent.parent / "datasets" / "test_dataset.csv")
|
||||||
dataset_id="test_dataset",
|
rows = df.to_dict(orient="records")
|
||||||
rows_in_page=3,
|
|
||||||
)
|
|
||||||
assert len(rows.rows) == 3
|
|
||||||
|
|
||||||
scoring_functions = {
|
scoring_functions = {
|
||||||
"llm-as-judge::base": dict(
|
"llm-as-judge::base": dict(
|
||||||
|
@ -149,24 +124,13 @@ def test_scoring_score_with_params_llm_as_judge(
|
||||||
}
|
}
|
||||||
|
|
||||||
response = llama_stack_client.scoring.score(
|
response = llama_stack_client.scoring.score(
|
||||||
input_rows=rows.rows,
|
input_rows=rows,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
)
|
)
|
||||||
assert len(response.results) == len(scoring_functions)
|
assert len(response.results) == len(scoring_functions)
|
||||||
for x in scoring_functions:
|
for x in scoring_functions:
|
||||||
assert x in response.results
|
assert x in response.results
|
||||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
assert len(response.results[x].score_rows) == len(rows)
|
||||||
|
|
||||||
# score batch
|
|
||||||
response = llama_stack_client.scoring.score_batch(
|
|
||||||
dataset_id="test_dataset",
|
|
||||||
scoring_functions=scoring_functions,
|
|
||||||
save_results_dataset=False,
|
|
||||||
)
|
|
||||||
assert len(response.results) == len(scoring_functions)
|
|
||||||
for x in scoring_functions:
|
|
||||||
assert x in response.results
|
|
||||||
assert len(response.results[x].score_rows) == 5
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -178,15 +142,20 @@ def test_scoring_score_with_params_llm_as_judge(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_scoring_score_with_aggregation_functions(
|
def test_scoring_score_with_aggregation_functions(
|
||||||
llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id, rag_dataset_for_test
|
llama_stack_client,
|
||||||
|
sample_judge_prompt_template,
|
||||||
|
judge_model_id,
|
||||||
|
provider_id,
|
||||||
|
rag_dataset_for_test,
|
||||||
):
|
):
|
||||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
df = pd.read_csv(Path(__file__).parent.parent / "datasets" / "test_dataset.csv")
|
||||||
dataset_id="test_dataset",
|
rows = df.to_dict(orient="records")
|
||||||
rows_in_page=3,
|
|
||||||
)
|
|
||||||
assert len(rows.rows) == 3
|
|
||||||
|
|
||||||
scoring_fns_list = [x for x in llama_stack_client.scoring_functions.list() if x.provider_id == provider_id]
|
scoring_fns_list = [
|
||||||
|
x
|
||||||
|
for x in llama_stack_client.scoring_functions.list()
|
||||||
|
if x.provider_id == provider_id
|
||||||
|
]
|
||||||
if len(scoring_fns_list) == 0:
|
if len(scoring_fns_list) == 0:
|
||||||
pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
|
pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")
|
||||||
|
|
||||||
|
@ -224,12 +193,12 @@ def test_scoring_score_with_aggregation_functions(
|
||||||
scoring_functions[scoring_fn.identifier] = None
|
scoring_functions[scoring_fn.identifier] = None
|
||||||
|
|
||||||
response = llama_stack_client.scoring.score(
|
response = llama_stack_client.scoring.score(
|
||||||
input_rows=rows.rows,
|
input_rows=rows,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(response.results) == len(scoring_functions)
|
assert len(response.results) == len(scoring_functions)
|
||||||
for x in scoring_functions:
|
for x in scoring_functions:
|
||||||
assert x in response.results
|
assert x in response.results
|
||||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
assert len(response.results[x].score_rows) == len(rows)
|
||||||
assert len(response.results[x].aggregated_results) == len(aggr_fns)
|
assert len(response.results[x].aggregated_results) == len(aggr_fns)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue