scoring test pass

This commit is contained in:
Xi Yan 2024-11-06 17:27:55 -08:00
parent 0351072531
commit 0bce74402f
4 changed files with 32 additions and 10 deletions

View file

@ -37,7 +37,6 @@ SCORING_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request):
fixture_dict = request.param
print("!!!", fixture_dict)
providers = {}
provider_data = {}
@ -56,5 +55,6 @@ async def scoring_stack(request):
return (
impls[Api.scoring],
impls[Api.scoring_functions],
impls[Api.datasetio],
impls[Api.datasets],
)

View file

@ -21,14 +21,36 @@ class TestScoring:
async def test_scoring_functions_list(self, scoring_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, scoring_functions_impl, _ = scoring_stack
_, scoring_functions_impl, _, _ = scoring_stack
response = await scoring_functions_impl.list_scoring_functions()
assert isinstance(response, list)
assert len(response) > 0
@pytest.mark.asyncio
async def test_scoring_score(self, scoring_stack):
scoring_impl, scoring_functions_impl, datasets_impl = scoring_stack
scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
scoring_stack
)
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_functions = [
"meta-reference::llm_as_judge_8b_correctness",
"meta-reference::equality",
]
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)