Mirror of https://github.com/meta-llama/llama-stack.git

commit 51c20f9c29 ("api refactor"), parent 97dcd5704c
8 changed files with 64 additions and 59 deletions
@@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score_batch(
                 dataset_id=dataset_id,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
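Net effect of this hunk: the old argument pair (scoring_functions: List[str] plus scoring_params) collapses into a single scoring_functions: Optional[Dict[str, ScoringFnParams]] mapping from function identifier to its parameters. A minimal sketch of a post-change call site; the stub router and the "basic::equality" identifier are illustrative stand-ins, not taken from this commit:

import asyncio
from typing import Any, Dict, Optional


class StubScoringRouter:
    # Stand-in for ScoringRouter, reduced to the new signature shape.
    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: Optional[Dict[str, Any]] = None,
        save_results_dataset: bool = False,
    ) -> Dict[str, Any]:
        # Pretend every requested function scored the whole dataset.
        return {fn: "scored" for fn in (scoring_functions or {})}


router = StubScoringRouter()
# Before: scoring_functions=["basic::equality"], scoring_params={...}
# After: one mapping; a None value means "use the function's defaults".
print(asyncio.run(router.score_batch(
    dataset_id="my_dataset",
    scoring_functions={"basic::equality": None},
)))

Note that despite the new None default, the mapping is effectively required: the router loop iterates it unconditionally.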
@@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score(
                 input_rows=input_rows,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
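score gets the same treatment as score_batch: the router iterates the mapping's keys, dispatches each identifier to its provider with a single-entry {id: params} dict, and merges the per-function results. A self-contained sketch of that fan-out pattern; the toy provider and in-memory routing table are illustrative assumptions, not the real routing-table API:

import asyncio
from typing import Any, Dict, List, Optional


class EchoProvider:
    # Toy provider: returns 1.0 for every row, per requested function.
    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Dict[str, Optional[Any]],
    ) -> Dict[str, List[float]]:
        return {fn: [1.0] * len(input_rows) for fn in scoring_functions}


async def fan_out_score(
    routing_table: Dict[str, EchoProvider],
    input_rows: List[Dict[str, Any]],
    scoring_functions: Dict[str, Optional[Any]],
) -> Dict[str, List[float]]:
    # Mirrors the router loop: one provider call per function id,
    # each receiving only its own {id: params} entry; results merged.
    res: Dict[str, List[float]] = {}
    for fn_identifier in scoring_functions:
        provider = routing_table[fn_identifier]
        partial = await provider.score(
            input_rows=input_rows,
            scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
        )
        res.update(partial)
    return res


print(asyncio.run(fan_out_score(
    routing_table={"basic::equality": EchoProvider()},
    input_rows=[{"expected": 1, "actual": 1}],
    scoring_functions={"basic::equality": None},
)))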
@@ -272,24 +268,24 @@ class EvalRouter(Eval):
     async def shutdown(self) -> None:
         pass
 
-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         pass
 
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
         # NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
         return await self.routing_table.get_provider_impl(
             DEFAULT_EVAL_TASK_IDENTIFIER
         ).run_eval(
-            eval_task_def=eval_task_def,
-            eval_task_config=eval_task_config,
+            task=task,
+            task_config=task_config,
         )
 
     @webmethod(route="/eval/evaluate_rows", method="POST")
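On the eval side, run_benchmark_eval becomes run_benchmark (taking benchmark_config), and run_eval moves to the shorter task / task_config keywords, with app-eval configs now typed as AppEvalTaskConfig. A hedged sketch of the renamed call, using stand-in dataclasses rather than the real EvalTaskDef / AppEvalTaskConfig types (fields are illustrative):

import asyncio
from dataclasses import dataclass


@dataclass
class EvalTaskDef:
    # Stand-in: the real type carries more fields than an identifier.
    identifier: str


@dataclass
class AppEvalTaskConfig:
    # Stand-in: field name is illustrative, not from this commit.
    num_examples: int = 10


class StubEvalRouter:
    # Reduced to the new keyword names introduced by this commit.
    async def run_eval(self, task: EvalTaskDef, task_config: AppEvalTaskConfig) -> str:
        return f"job:{task.identifier}"


router = StubEvalRouter()
# Old: run_eval(eval_task_def=..., eval_task_config=...)
# New: run_eval(task=..., task_config=...)
print(asyncio.run(router.run_eval(
    task=EvalTaskDef(identifier="my-app-eval"),
    task_config=AppEvalTaskConfig(num_examples=5),
)))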