score batch

This commit is contained in:
Xi Yan 2024-10-23 16:38:00 -07:00
parent eb572faf6f
commit 3c6555c408
4 changed files with 34 additions and 10 deletions

View file

@ -113,7 +113,7 @@ async def run_main(host: str, port: int):
input_rows=response.rows,
scoring_functions=["equality"],
)
cprint(f"scoring response={response}", "blue")
cprint(f"score response={response}", "blue")
# test scoring batch using datasetio api
scoring_client = ScoringClient(f"http://{host}:{port}")
@ -121,7 +121,7 @@ async def run_main(host: str, port: int):
dataset_id="test-dataset",
scoring_functions=["equality"],
)
cprint(f"scoring response={response}", "blue")
cprint(f"score_batch response={response}", "cyan")
def main(host: str, port: int):

View file

@ -18,7 +18,8 @@ ScoringResult = Dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
dataset_id: str
dataset_id: Optional[str] = None
results: Dict[str, ScoringResult]
@json_schema_type
@ -37,7 +38,10 @@ class Scoring(Protocol):
@webmethod(route="/scoring/score_batch")
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")

View file

@ -209,9 +209,12 @@ class ScoringRouter(Scoring):
pass
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
print("Score Batch!")
res = {}
for fn_identifier in scoring_functions:
score_response = await self.routing_table.get_provider_impl(
fn_identifier
@ -219,7 +222,14 @@ class ScoringRouter(Scoring):
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
)
print(score_response)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]

View file

@ -37,7 +37,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
) -> None:
self.config = config
self.datasetio_api = datasetio_api
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red")
self.datasets_api = datasets_api
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasets_api}", "red")
async def initialize(self) -> None: ...
@ -52,7 +53,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
)
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
rows_paginated = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@ -61,8 +65,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
res = await self.score(
input_rows=rows_paginated.rows, scoring_functions=scoring_functions
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()
raise NotImplementedError("Save results dataset not implemented yet")
cprint(f"res: {res}", "green")
return ScoreBatchResponse(
results=res.results,
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]