diff --git a/llama_stack/apis/scoring/client.py b/llama_stack/apis/scoring/client.py
index 0e3998275..24f6400b4 100644
--- a/llama_stack/apis/scoring/client.py
+++ b/llama_stack/apis/scoring/client.py
@@ -36,7 +36,10 @@ class ScoringClient(Scoring):
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/scoring/score_batch",
-                params={},
+                json={
+                    "dataset_id": dataset_id,
+                    "scoring_functions": scoring_functions,
+                },
                 headers={"Content-Type": "application/json"},
                 timeout=60,
             )
@@ -44,7 +47,7 @@ class ScoringClient(Scoring):
             if not response.json():
                 return
 
-            return ScoreResponse(**response.json())
+            return ScoreBatchResponse(**response.json())
 
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
@@ -112,6 +115,14 @@ async def run_main(host: str, port: int):
     )
     cprint(f"scoring response={response}", "blue")
 
+    # test scoring batch using datasetio api
+    scoring_client = ScoringClient(f"http://{host}:{port}")
+    response = await scoring_client.score_batch(
+        dataset_id="test-dataset",
+        scoring_functions=["equality"],
+    )
+    cprint(f"scoring response={response}", "blue")
+
 
 def main(host: str, port: int):
     asyncio.run(run_main(host, port))
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index c85ba47d0..f398c7f25 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -211,8 +211,20 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self, dataset_id: str, scoring_functions: List[str]
     ) -> ScoreBatchResponse:
-        # TODO
-        pass
+        # Each scoring function may be served by a different provider, so
+        # dispatch them one at a time and merge the per-function results.
+        # NOTE(review): assumes ScoreBatchResponse exposes a `results`
+        # mapping keyed by scoring function identifier -- confirm.
+        res = {}
+        for fn_identifier in scoring_functions:
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score_batch(
+                dataset_id=dataset_id,
+                scoring_functions=[fn_identifier],
+            )
+            res.update(score_response.results)
+        return ScoreBatchResponse(results=res)
 
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py
index 48e177324..d1b6b371c 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/__init__.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py
@@ -17,6 +17,6 @@ async def get_provider_impl(
     print("get_provider_impl", deps)
     from .scoring import MetaReferenceScoringImpl
 
-    impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
+    impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
index 17763413d..315fd887a 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -10,6 +10,7 @@ from llama_stack.apis.scoring import *  # noqa: F403
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.datasets import *  # noqa: F403
 
 from termcolor import cprint
 
@@ -29,7 +30,11 @@ SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORE
 
 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
-        self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
+        self,
+        config: MetaReferenceScoringConfig,
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
     ) -> None:
         self.config = config
+        self.datasets_api = datasets_api
         self.datasetio_api = datasetio_api
@@ -50,7 +55,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self, dataset_id: str, scoring_functions: List[str]
     ) -> ScoreBatchResponse:
-        print("score_batch")
+        # NOTE(review): rows_in_page=-1 presumably requests every row of the
+        # dataset in a single page -- confirm against the DatasetIO provider.
+        rows_paginated = await self.datasetio_api.get_rows_paginated(
+            dataset_id=dataset_id,
+            rows_in_page=-1,
+        )
+        res = await self.score(
+            input_rows=rows_paginated.rows, scoring_functions=scoring_functions
+        )
+
+        # NOTE(review): assumes ScoreResponse and ScoreBatchResponse both carry
+        # a `results` field -- confirm against llama_stack.apis.scoring.
+        return ScoreBatchResponse(results=res.results)
 
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py
index 69af25839..4543449b4 100644
--- a/llama_stack/providers/registry/scoring.py
+++ b/llama_stack/providers/registry/scoring.py
@@ -19,6 +19,7 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
             api_dependencies=[
                 Api.datasetio,
+                Api.datasets,
             ],
         ),
     ]