[Evals API][10/n] API updates for EvalTaskDef + new test migration (#379)

* wip

* scoring fn api

* eval api

* eval task

* evaluate api update

* pre commit

* unwrap context -> config

* config field doc

* typo

* naming fix

* separate benchmark / app eval

* api name

* rename

* wip tests

* wip

* datasetio test

* delete unused

* fixture

* scoring resolve

* fix scoring register

* scoring test pass

* score batch

* scoring fix

* fix eval

* test eval works

* remove type ignore

* api refactor

* add default task_eval_id for routing

* add eval_id for jobs

* remove type ignore

* only keep 1 run_eval

* fix optional

* register task required

* register task required

* delete old tests

* delete old tests

* fixture return impl
This commit is contained in:
Xi Yan 2024-11-07 21:24:12 -08:00 committed by GitHub
parent 8350f2df4c
commit 6192bf43a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 916 additions and 389 deletions

View file

@ -14,6 +14,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
class MemoryRouter(Memory):
@ -211,16 +212,16 @@ class ScoringRouter(Scoring):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
@ -232,17 +233,87 @@ class ScoringRouter(Scoring):
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
input_rows=input_rows,
scoring_functions=[fn_identifier],
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_eval(
self,
task_id: str,
task_config: AppEvalTaskConfig,
) -> Job:
return await self.routing_table.get_provider_impl(task_id).run_eval(
task_id=task_id,
task_config=task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
task_id=task_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
)
async def job_status(
self,
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(task_id).job_status(
task_id, job_id
)
async def job_cancel(
self,
task_id: str,
job_id: str,
) -> None:
await self.routing_table.get_provider_impl(task_id).job_cancel(
task_id,
job_id,
)
async def job_result(
self,
task_id: str,
job_id: str,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).job_result(
task_id,
job_id,
)