From 7c803cef8635d1bb114a37fe8cbb09e02ddafe76 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 23 Oct 2024 17:22:48 -0700 Subject: [PATCH] update scoring test --- .../impls/meta_reference/scoring/__init__.py | 1 - .../impls/meta_reference/scoring/scoring.py | 3 -- .../providers/tests/scoring/test_scoring.py | 29 +++++++++++++++---- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py index d1b6b371c..69d9b543a 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/__init__.py +++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py @@ -14,7 +14,6 @@ async def get_provider_impl( config: MetaReferenceScoringConfig, deps: Dict[Api, ProviderSpec], ): - print("get_provider_impl", deps) from .scoring import MetaReferenceScoringImpl impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 0015e19bc..8b07bcbd8 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -12,8 +12,6 @@ from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 -from termcolor import cprint - from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import ( EqualityScorer, @@ -38,7 +36,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): self.config = config self.datasetio_api = datasetio_api self.datasets_api = datasets_api - cprint(f"!!! MetaReferenceScoringImpl init {config} {datasets_api}", "red") async def initialize(self) -> None: ... diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index dccfc78fc..1af5c05cf 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -10,6 +10,7 @@ from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: @@ -36,14 +37,32 @@ async def scoring_settings(): return { "scoring_impl": impls[Api.scoring], "scoring_functions_impl": impls[Api.scoring_functions], + "datasets_impl": impls[Api.datasets], } @pytest.mark.asyncio async def test_scoring_functions_list(scoring_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful scoring_functions_impl = scoring_settings["scoring_functions_impl"] - response = await scoring_functions_impl.list_scoring_functions() - assert isinstance(response, list) - assert len(response) == 0 + scoring_functions = await scoring_functions_impl.list_scoring_functions() + assert isinstance(scoring_functions, list) + assert len(scoring_functions) > 0 + function_ids = [f.identifier for f in scoring_functions] + assert "equality" in function_ids + + +@pytest.mark.asyncio +async def test_scoring_score(scoring_settings): + scoring_impl = scoring_settings["scoring_impl"] + datasets_impl = scoring_settings["datasets_impl"] + await register_dataset(datasets_impl) + + response = await datasets_impl.list_datasets() + assert len(response) == 1 + + response = await scoring_impl.score_batch( + dataset_id=response[0].identifier, + scoring_functions=["equality"], + ) + + assert len(response.results) == 1