diff --git a/llama_stack/apis/scoring/client.py b/llama_stack/apis/scoring/client.py index 24f6400b4..f08fa4bc0 100644 --- a/llama_stack/apis/scoring/client.py +++ b/llama_stack/apis/scoring/client.py @@ -113,7 +113,7 @@ async def run_main(host: str, port: int): input_rows=response.rows, scoring_functions=["equality"], ) - cprint(f"scoring response={response}", "blue") + cprint(f"score response={response}", "blue") # test scoring batch using datasetio api scoring_client = ScoringClient(f"http://{host}:{port}") @@ -121,7 +121,7 @@ async def run_main(host: str, port: int): dataset_id="test-dataset", scoring_functions=["equality"], ) - cprint(f"scoring response={response}", "blue") + cprint(f"score_batch response={response}", "cyan") def main(host: str, port: int): diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index edb36c75e..54d8f7487 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -18,7 +18,8 @@ ScoringResult = Dict[str, Any] @json_schema_type class ScoreBatchResponse(BaseModel): - dataset_id: str + dataset_id: Optional[str] = None + results: Dict[str, ScoringResult] @json_schema_type @@ -37,7 +38,10 @@ class Scoring(Protocol): @webmethod(route="/scoring/score_batch") async def score_batch( - self, dataset_id: str, scoring_functions: List[str] + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @webmethod(route="/scoring/score") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index f398c7f25..348d8449d 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -209,9 +209,12 @@ class ScoringRouter(Scoring): pass async def score_batch( - self, dataset_id: str, scoring_functions: List[str] + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, ) -> ScoreBatchResponse: - print("Score Batch!") + res = {} for fn_identifier in scoring_functions: score_response = await self.routing_table.get_provider_impl( fn_identifier @@ -219,7 +222,14 @@ class ScoringRouter(Scoring): dataset_id=dataset_id, scoring_functions=[fn_identifier], ) - print(score_response) + res.update(score_response.results) + + if save_results_dataset: + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res, + ) async def score( self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 315fd887a..0015e19bc 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -37,7 +37,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) -> None: self.config = config self.datasetio_api = datasetio_api - cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red") + self.datasets_api = datasets_api + cprint(f"!!! MetaReferenceScoringImpl init {config} {datasets_api}", "red") async def initialize(self) -> None: ... @@ -52,7 +53,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) async def score_batch( - self, dataset_id: str, scoring_functions: List[str] + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, ) -> ScoreBatchResponse: rows_paginated = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, @@ -61,8 +65,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): res = await self.score( input_rows=rows_paginated.rows, scoring_functions=scoring_functions ) + if save_results_dataset: + # TODO: persist and register dataset on to server for reading + # self.datasets_api.register_dataset() + raise NotImplementedError("Save results dataset not implemented yet") - cprint(f"res: {res}", "green") + return ScoreBatchResponse( + results=res.results, + ) async def score( self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]