score batch

This commit is contained in:
Xi Yan 2024-10-23 16:38:00 -07:00
parent eb572faf6f
commit 3c6555c408
4 changed files with 34 additions and 10 deletions

View file

@ -37,7 +37,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
) -> None:
self.config = config
self.datasetio_api = datasetio_api
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red")
self.datasets_api = datasets_api
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasets_api}", "red")
async def initialize(self) -> None: ...
@ -52,7 +53,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
)
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
rows_paginated = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@ -61,8 +65,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
res = await self.score(
input_rows=rows_paginated.rows, scoring_functions=scoring_functions
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
cprint(f"res: {res}", "green")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]