diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py
index ee578f9b3..513180ef4 100644
--- a/llama_stack/providers/tests/scoring/conftest.py
+++ b/llama_stack/providers/tests/scoring/conftest.py
@@ -16,7 +16,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "datasetio": "localfs",
             "inference": "fireworks",
         },
         id="meta_reference_scoring_fireworks_inference",
@@ -25,7 +25,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "datasetio": "localfs",
             "inference": "together",
         },
         id="meta_reference_scoring_together_inference",
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index 925f98779..96409d200 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -52,9 +52,4 @@ async def scoring_stack(request):
         provider_data,
     )
 
-    return (
-        impls[Api.scoring],
-        impls[Api.scoring_functions],
-        impls[Api.datasetio],
-        impls[Api.datasets],
-    )
+    return impls
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 3c1b6554f..170073eeb 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -8,7 +8,7 @@ import pytest
 
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 
-
+from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
 # How to run this test:
@@ -23,20 +23,36 @@ class TestScoring:
     async def test_scoring_functions_list(self, scoring_stack):
         # NOTE: this needs you to ensure that you are starting from a clean state
         # but so far we don't have an unregister API unfortunately, so be careful
-        _, scoring_functions_impl, _, _ = scoring_stack
+        scoring_functions_impl = scoring_stack[Api.scoring_functions]
         response = await scoring_functions_impl.list_scoring_functions()
         assert isinstance(response, list)
         assert len(response) > 0
 
     @pytest.mark.asyncio
     async def test_scoring_score(self, scoring_stack):
-        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
-            scoring_stack
+        (
+            scoring_impl,
+            scoring_functions_impl,
+            datasetio_impl,
+            datasets_impl,
+            models_impl,
+        ) = (
+            scoring_stack[Api.scoring],
+            scoring_stack[Api.scoring_functions],
+            scoring_stack[Api.datasetio],
+            scoring_stack[Api.datasets],
+            scoring_stack[Api.models],
         )
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
 
+        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
+
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset",
@@ -69,13 +85,29 @@ class TestScoring:
 
     @pytest.mark.asyncio
     async def test_scoring_score_with_params(self, scoring_stack):
-        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
-            scoring_stack
+        (
+            scoring_impl,
+            scoring_functions_impl,
+            datasetio_impl,
+            datasets_impl,
+            models_impl,
+        ) = (
+            scoring_stack[Api.scoring],
+            scoring_stack[Api.scoring_functions],
+            scoring_stack[Api.datasetio],
+            scoring_stack[Api.datasets],
+            scoring_stack[Api.models],
         )
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
 
+        for model_id in ["Llama3.1-405B-Instruct"]:
+            await models_impl.register_model(
+                model_id=model_id,
+                provider_id="",
+            )
+
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset",