Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Commit 7c803cef86 (parent 3c6555c408): update scoring test
3 changed files with 24 additions and 9 deletions
@@ -14,7 +14,6 @@ async def get_provider_impl(
     config: MetaReferenceScoringConfig,
     deps: Dict[Api, ProviderSpec],
 ):
-    print("get_provider_impl", deps)
     from .scoring import MetaReferenceScoringImpl

     impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
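For readers unfamiliar with this entry point: in-tree providers expose a `get_provider_impl` factory that receives the provider's config plus a dict of already-resolved dependency implementations, and wires those into the concrete impl. A minimal self-contained sketch of that pattern, using stand-in types rather than the real llama-stack datatypes (whether the factory awaits `initialize()` is also an assumption here):

# Sketch of the provider-factory pattern shown in the hunk above.
# Api, ScoringImpl, and the stub deps are illustrative stand-ins.
import asyncio
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    datasetio = "datasetio"
    datasets = "datasets"


class ScoringImpl:
    def __init__(self, config: Any, datasetio_api: Any, datasets_api: Any):
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api

    async def initialize(self) -> None:
        ...


async def get_provider_impl(config: Any, deps: Dict[Api, Any]) -> ScoringImpl:
    # mirror the wiring in the diff: pull the needed deps out of the dict
    impl = ScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
    await impl.initialize()
    return impl


asyncio.run(get_provider_impl(None, {Api.datasetio: object(), Api.datasets: object()}))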
@@ -12,8 +12,6 @@ from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403

-from termcolor import cprint
-
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
     EqualityScorer,
@@ -38,7 +36,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
-        cprint(f"!!! MetaReferenceScoringImpl init {config} {datasets_api}", "red")

     async def initialize(self) -> None: ...

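The `EqualityScorer` imported above is what the updated test asserts on. As far as the test reveals, its contract is exact-match scoring of rows; a minimal sketch under that assumption (the `expected_answer`/`generated_answer` row keys are guesses, not confirmed field names):

# Illustrative exact-match scorer; not the actual EqualityScorer code.
from typing import Any, Dict, List


class EqualityScorerSketch:
    """Scores a row 1.0 if the generated answer exactly matches the
    expected answer, else 0.0."""

    def score_row(self, row: Dict[str, Any]) -> float:
        return 1.0 if row.get("generated_answer") == row.get("expected_answer") else 0.0

    def score(self, rows: List[Dict[str, Any]]) -> List[float]:
        return [self.score_row(r) for r in rows]


rows = [
    {"expected_answer": "4", "generated_answer": "4"},
    {"expected_answer": "4", "generated_answer": "5"},
]
assert EqualityScorerSketch().score(rows) == [1.0, 0.0]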
@@ -10,6 +10,7 @@ from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403

+from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 from llama_stack.providers.tests.resolver import resolve_impls_for_test

 # How to run this test:
@@ -36,14 +37,32 @@ async def scoring_settings():
     return {
         "scoring_impl": impls[Api.scoring],
         "scoring_functions_impl": impls[Api.scoring_functions],
+        "datasets_impl": impls[Api.datasets],
     }


 @pytest.mark.asyncio
 async def test_scoring_functions_list(scoring_settings):
-    # NOTE: this needs you to ensure that you are starting from a clean state
-    # but so far we don't have an unregister API unfortunately, so be careful
     scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    response = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(response, list)
-    assert len(response) == 0
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    assert isinstance(scoring_functions, list)
+    assert len(scoring_functions) > 0
+    function_ids = [f.identifier for f in scoring_functions]
+    assert "equality" in function_ids
+
+
+@pytest.mark.asyncio
+async def test_scoring_score(scoring_settings):
+    scoring_impl = scoring_settings["scoring_impl"]
+    datasets_impl = scoring_settings["datasets_impl"]
+    await register_dataset(datasets_impl)
+
+    response = await datasets_impl.list_datasets()
+    assert len(response) == 1
+
+    response = await scoring_impl.score_batch(
+        dataset_id=response[0].identifier,
+        scoring_functions=["equality"],
+    )
+
+    assert len(response.results) == 1
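The new `test_scoring_score` drives the full loop: register a dataset, list it back, then score it with the `equality` function and expect one entry in `response.results` (presumably one per requested scoring function). A self-contained sketch of that flow, with the response shapes inferred only from the test's assertions rather than the actual llama-stack datatypes:

# Stand-in types and flow for score_batch; shapes are assumptions.
import asyncio
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ScoringResult:
    score_rows: List[Dict[str, float]] = field(default_factory=list)


@dataclass
class ScoreBatchResponse:
    # one entry per requested scoring function
    results: Dict[str, ScoringResult] = field(default_factory=dict)


async def score_batch(dataset_id: str, scoring_functions: List[str]) -> ScoreBatchResponse:
    # stand-in: score every row of a pretend dataset with each function
    rows = [{"expected_answer": "4", "generated_answer": "4"}]
    return ScoreBatchResponse(
        results={
            fn: ScoringResult(score_rows=[{"score": 1.0} for _ in rows])
            for fn in scoring_functions
        }
    )


async def main() -> None:
    response = await score_batch("test-dataset", scoring_functions=["equality"])
    assert len(response.results) == 1  # mirrors the test's final assertion


asyncio.run(main())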