[Evals API][10/n] API updates for EvalTaskDef + new test migration (#379)

* wip

* scoring fn api

* eval api

* eval task

* evaluate api update

* pre commit

* unwrap context -> config

* config field doc

* typo

* naming fix

* separate benchmark / app eval

* api name

* rename

* wip tests

* wip

* datasetio test

* delete unused

* fixture

* scoring resolve

* fix scoring register

* scoring test pass

* score batch

* scoring fix

* fix eval

* test eval works

* remove type ignore

* api refactor

* add default task_eval_id for routing

* add eval_id for jobs

* remove type ignore

* only keep 1 run_eval

* fix optional

* register task required

* register task required

* delete old tests

* delete old tests

* fixture return impl
This commit is contained in:
Xi Yan 2024-11-07 21:24:12 -08:00 committed by GitHub
parent 8350f2df4c
commit 6192bf43a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 916 additions and 389 deletions

View file

@ -6,13 +6,15 @@
from enum import Enum
from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.eval_tasks import EvalTaskDef
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from .config import MetaReferenceEvalConfig
@ -25,7 +27,7 @@ class ColumnName(Enum):
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval):
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
def __init__(
self,
config: MetaReferenceEvalConfig,
@ -43,10 +45,18 @@ class MetaReferenceEvalImpl(Eval):
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.eval_tasks = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def register_eval_task(self, task_def: EvalTaskDef) -> None:
self.eval_tasks[task_def.identifier] = task_def
async def list_eval_tasks(self) -> List[EvalTaskDef]:
return list(self.eval_tasks.values())
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@ -70,21 +80,26 @@ class MetaReferenceEvalImpl(Eval):
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def evaluate_batch(
async def run_eval(
self,
dataset_id: str,
candidate: EvalCandidate,
scoring_functions: List[str],
task_id: str,
task_config: EvalTaskConfig,
) -> Job:
task_def = self.eval_tasks[task_id]
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.evaluate(
res = await self.evaluate_rows(
task_id=task_id,
input_rows=all_rows.rows,
candidate=candidate,
scoring_functions=scoring_functions,
task_config=task_config,
)
# TODO: currently needs to wait for generation before returning
@ -93,12 +108,14 @@ class MetaReferenceEvalImpl(Eval):
self.jobs[job_id] = res
return Job(job_id=job_id)
async def evaluate(
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
candidate: EvalCandidate,
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
@ -122,7 +139,10 @@ class MetaReferenceEvalImpl(Eval):
}
)
elif ColumnName.chat_completion_input.value in x:
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
chat_completion_input_str = str(
x[ColumnName.chat_completion_input.value]
)
input_messages = eval(chat_completion_input_str)
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
if candidate.system_message:
@ -147,23 +167,33 @@ class MetaReferenceEvalImpl(Eval):
for input_r, generated_r in zip(input_rows, generations)
]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
if job_id in self.jobs:
return JobStatus.completed
return None
async def job_cancel(self, job_id: str) -> None:
async def job_cancel(self, task_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, job_id: str) -> EvaluateResponse:
status = await self.job_status(job_id)
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(task_id, job_id)
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")