api refactor

This commit is contained in:
Xi Yan 2024-11-07 13:54:26 -08:00
parent 97dcd5704c
commit 51c20f9c29
8 changed files with 64 additions and 59 deletions

View file

@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
scoring_params=scoring_params,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
input_rows=input_rows,
scoring_functions=[fn_identifier],
scoring_params=scoring_params,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
@ -272,24 +268,24 @@ class EvalRouter(Eval):
async def shutdown(self) -> None:
pass
async def run_benchmark_eval(
async def run_benchmark(
self,
benchmark_id: str,
eval_task_config: BenchmarkEvalTaskConfig,
benchmark_config: BenchmarkEvalTaskConfig,
) -> Job:
pass
async def run_eval(
self,
eval_task_def: EvalTaskDef,
eval_task_config: EvalTaskConfig,
task: EvalTaskDef,
task_config: AppEvalTaskConfig,
) -> Job:
# NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
return await self.routing_table.get_provider_impl(
DEFAULT_EVAL_TASK_IDENTIFIER
).run_eval(
eval_task_def=eval_task_def,
eval_task_config=eval_task_config,
task=task,
task_config=task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")