update scoring test

This commit is contained in:
Xi Yan 2024-10-23 17:22:48 -07:00
parent 3c6555c408
commit 7c803cef86
3 changed files with 24 additions and 9 deletions

View file

@@ -14,7 +14,6 @@ async def get_provider_impl(
config: MetaReferenceScoringConfig, config: MetaReferenceScoringConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, ProviderSpec],
): ):
print("get_provider_impl", deps)
from .scoring import MetaReferenceScoringImpl from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])

View file

@@ -12,8 +12,6 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403
from termcolor import cprint
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import ( from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
EqualityScorer, EqualityScorer,
@@ -38,7 +36,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
self.config = config self.config = config
self.datasetio_api = datasetio_api self.datasetio_api = datasetio_api
self.datasets_api = datasets_api self.datasets_api = datasets_api
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasets_api}", "red")
async def initialize(self) -> None: ... async def initialize(self) -> None: ...

View file

@@ -10,6 +10,7 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
from llama_stack.providers.tests.resolver import resolve_impls_for_test from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test: # How to run this test:
@@ -36,14 +37,32 @@ async def scoring_settings():
return { return {
"scoring_impl": impls[Api.scoring], "scoring_impl": impls[Api.scoring],
"scoring_functions_impl": impls[Api.scoring_functions], "scoring_functions_impl": impls[Api.scoring_functions],
"datasets_impl": impls[Api.datasets],
} }
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scoring_functions_list(scoring_settings): async def test_scoring_functions_list(scoring_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
scoring_functions_impl = scoring_settings["scoring_functions_impl"] scoring_functions_impl = scoring_settings["scoring_functions_impl"]
response = await scoring_functions_impl.list_scoring_functions() scoring_functions = await scoring_functions_impl.list_scoring_functions()
assert isinstance(response, list) assert isinstance(scoring_functions, list)
assert len(response) == 0 assert len(scoring_functions) > 0
function_ids = [f.identifier for f in scoring_functions]
assert "equality" in function_ids
@pytest.mark.asyncio
async def test_scoring_score(scoring_settings):
scoring_impl = scoring_settings["scoring_impl"]
datasets_impl = scoring_settings["datasets_impl"]
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
response = await scoring_impl.score_batch(
dataset_id=response[0].identifier,
scoring_functions=["equality"],
)
assert len(response.results) == 1