scoring fix

This commit is contained in:
Xi Yan 2024-11-06 18:07:16 -08:00
parent c5cf9c30be
commit 56239fce90
10 changed files with 104 additions and 15 deletions

View file

@ -212,6 +212,7 @@ class ScoringRouter(Scoring):
self,
dataset_id: str,
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
res = {}
@ -221,6 +222,7 @@ class ScoringRouter(Scoring):
).score_batch(
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
scoring_params=scoring_params,
)
res.update(score_response.results)
@ -232,7 +234,10 @@ class ScoringRouter(Scoring):
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
# look up and map each scoring function to its provider impl
@ -242,6 +247,7 @@ class ScoringRouter(Scoring):
).score(
input_rows=input_rows,
scoring_functions=[fn_identifier],
scoring_params=scoring_params,
)
res.update(score_response.results)