fix scoring test

Xi Yan 2024-11-11 15:33:56 -05:00
parent e27c6e3662
commit 68a4e6d00e
3 changed files with 41 additions and 14 deletions

@@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "datasetio": "localfs",
             "inference": "fireworks",
         },
         id="meta_reference_scoring_fireworks_inference",
@@ -25,7 +25,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "datasetio": "localfs",
             "inference": "together",
         },
         id="meta_reference_scoring_together_inference",

@@ -52,9 +52,4 @@ async def scoring_stack(request):
         provider_data,
     )
-    return (
-        impls[Api.scoring],
-        impls[Api.scoring_functions],
-        impls[Api.datasetio],
-        impls[Api.datasets],
-    )
+    return impls
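
With this change the scoring_stack fixture yields the full Api-to-implementation mapping instead of a fixed four-tuple, so each test picks the impls it needs by key and new APIs (here Api.models) can be exposed without breaking existing call sites. A minimal sketch of the new access pattern, mirroring the test changes below (Api is the enum imported from llama_stack.distribution.datatypes):

    impls = scoring_stack  # the fixture now returns the dict keyed by Api
    scoring_functions_impl = impls[Api.scoring_functions]
    models_impl = impls[Api.models]  # newly reachable without positional unpacking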

@@ -8,7 +8,7 @@
 import pytest
 from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset

 # How to run this test:
@@ -23,20 +23,36 @@ class TestScoring:
     async def test_scoring_functions_list(self, scoring_stack):
         # NOTE: this needs you to ensure that you are starting from a clean state
         # but so far we don't have an unregister API unfortunately, so be careful
-        _, scoring_functions_impl, _, _ = scoring_stack
+        scoring_functions_impl = scoring_stack[Api.scoring_functions]
         response = await scoring_functions_impl.list_scoring_functions()
         assert isinstance(response, list)
         assert len(response) > 0

     @pytest.mark.asyncio
     async def test_scoring_score(self, scoring_stack):
-        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
-            scoring_stack
+        (
+            scoring_impl,
+            scoring_functions_impl,
+            datasetio_impl,
+            datasets_impl,
+            models_impl,
+        ) = (
+            scoring_stack[Api.scoring],
+            scoring_stack[Api.scoring_functions],
+            scoring_stack[Api.datasetio],
+            scoring_stack[Api.datasets],
+            scoring_stack[Api.models],
         )
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1

+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
+
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset",
@@ -69,13 +85,29 @@ class TestScoring:

     @pytest.mark.asyncio
     async def test_scoring_score_with_params(self, scoring_stack):
-        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
-            scoring_stack
+        (
+            scoring_impl,
+            scoring_functions_impl,
+            datasetio_impl,
+            datasets_impl,
+            models_impl,
+        ) = (
+            scoring_stack[Api.scoring],
+            scoring_stack[Api.scoring_functions],
+            scoring_stack[Api.datasetio],
+            scoring_stack[Api.datasets],
+            scoring_stack[Api.models],
         )
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1

+        for model_id in ["Llama3.1-405B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
+
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset",
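
For context, both tests now register models before scoring any rows, presumably because some of the registered scoring functions run inference (an LLM-judge style scorer, for instance) and would fail if the referenced models did not exist. A sketch of that step, lifted from the hunks above; reading the empty provider_id as "defer to the single configured inference provider" is an assumption:

    models_impl = scoring_stack[Api.models]
    for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
        await models_impl.register_model(
            model_id=model_id,
            provider_id="",  # assumed: empty string lets the stack pick the default provider
        )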