mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-05 05:35:22 +00:00
[Evals API][10/n] API updates for EvalTaskDef + new test migration (#379)
* wip * scoring fn api * eval api * eval task * evaluate api update * pre commit * unwrap context -> config * config field doc * typo * naming fix * separate benchmark / app eval * api name * rename * wip tests * wip * datasetio test * delete unused * fixture * scoring resolve * fix scoring register * scoring test pass * score batch * scoring fix * fix eval * test eval works * remove type ignore * api refactor * add default task_eval_id for routing * add eval_id for jobs * remove type ignore * only keep 1 run_eval * fix optional * register task required * register task required * delete old tests * delete old tests * fixture return impl
This commit is contained in:
parent
8350f2df4c
commit
6192bf43a4
32 changed files with 916 additions and 389 deletions
|
@ -6,13 +6,15 @@
|
|||
from enum import Enum
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.eval_tasks import EvalTaskDef
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
@ -25,7 +27,7 @@ class ColumnName(Enum):
|
|||
generated_answer = "generated_answer"
|
||||
|
||||
|
||||
class MetaReferenceEvalImpl(Eval):
|
||||
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceEvalConfig,
|
||||
|
@ -43,10 +45,18 @@ class MetaReferenceEvalImpl(Eval):
|
|||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs = {}
|
||||
|
||||
self.eval_tasks = {}
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_eval_task(self, task_def: EvalTaskDef) -> None:
|
||||
self.eval_tasks[task_def.identifier] = task_def
|
||||
|
||||
async def list_eval_tasks(self) -> List[EvalTaskDef]:
|
||||
return list(self.eval_tasks.values())
|
||||
|
||||
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
|
@ -70,21 +80,26 @@ class MetaReferenceEvalImpl(Eval):
|
|||
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||
)
|
||||
|
||||
async def evaluate_batch(
|
||||
async def run_eval(
|
||||
self,
|
||||
dataset_id: str,
|
||||
candidate: EvalCandidate,
|
||||
scoring_functions: List[str],
|
||||
task_id: str,
|
||||
task_config: EvalTaskConfig,
|
||||
) -> Job:
|
||||
task_def = self.eval_tasks[task_id]
|
||||
dataset_id = task_def.dataset_id
|
||||
candidate = task_config.eval_candidate
|
||||
scoring_functions = task_def.scoring_functions
|
||||
|
||||
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
res = await self.evaluate(
|
||||
res = await self.evaluate_rows(
|
||||
task_id=task_id,
|
||||
input_rows=all_rows.rows,
|
||||
candidate=candidate,
|
||||
scoring_functions=scoring_functions,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
|
@ -93,12 +108,14 @@ class MetaReferenceEvalImpl(Eval):
|
|||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id)
|
||||
|
||||
async def evaluate(
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
task_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
candidate: EvalCandidate,
|
||||
scoring_functions: List[str],
|
||||
task_config: EvalTaskConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = task_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
raise NotImplementedError(
|
||||
"Evaluation with generation has not been implemented for agents"
|
||||
|
@ -122,7 +139,10 @@ class MetaReferenceEvalImpl(Eval):
|
|||
}
|
||||
)
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
|
||||
chat_completion_input_str = str(
|
||||
x[ColumnName.chat_completion_input.value]
|
||||
)
|
||||
input_messages = eval(chat_completion_input_str)
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
|
@ -147,23 +167,33 @@ class MetaReferenceEvalImpl(Eval):
|
|||
for input_r, generated_r in zip(input_rows, generations)
|
||||
]
|
||||
|
||||
if task_config.type == "app" and task_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
}
|
||||
else:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: None for scoring_fn_id in scoring_functions
|
||||
}
|
||||
|
||||
score_response = await self.scoring_api.score(
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
|
||||
)
|
||||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
|
||||
if job_id in self.jobs:
|
||||
return JobStatus.completed
|
||||
|
||||
return None
|
||||
|
||||
async def job_cancel(self, job_id: str) -> None:
|
||||
async def job_cancel(self, task_id: str, job_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(job_id)
|
||||
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(task_id, job_id)
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
|
|
|
@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
return scoring_fn_defs_list
|
||||
|
||||
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
|
||||
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
|
||||
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
|
||||
raise NotImplementedError("Register scoring function not implemented yet")
|
||||
|
||||
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
||||
|
@ -97,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
|
@ -106,7 +105,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
rows_in_page=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows, scoring_functions=scoring_functions
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
|
@ -118,14 +118,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
)
|
||||
|
||||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions:
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
score_results = await scoring_fn.score(input_rows, scoring_fn_id)
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
agg_results = await scoring_fn.aggregate(score_results)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
|
|
|
@ -36,7 +36,10 @@ class BaseScoringFn(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def score_row(
|
||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -50,8 +53,9 @@ class BaseScoringFn(ABC):
|
|||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> List[ScoringResultRow]:
|
||||
return [
|
||||
await self.score_row(input_row, scoring_fn_identifier)
|
||||
await self.score_row(input_row, scoring_fn_identifier, scoring_params)
|
||||
for input_row in input_rows
|
||||
]
|
||||
|
|
|
@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn):
|
|||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "equality",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||
assert (
|
||||
|
|
|
@ -28,9 +28,13 @@ llm_as_judge_8b_correctness = ScoringFnDef(
|
|||
description="Llm As Judge Scoring Function",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
context=LLMAsJudgeContext(
|
||||
params=LLMAsJudgeScoringFnParams(
|
||||
prompt_template=JUDGE_PROMPT,
|
||||
judge_model="Llama3.1-8B-Instruct",
|
||||
judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
|
||||
judge_score_regexes=[
|
||||
r"Total rating: (\d+)",
|
||||
r"rating: (\d+)",
|
||||
r"Rating: (\d+)",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
|
|
@ -36,31 +36,37 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
|||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert (
|
||||
scoring_fn_identifier is not None
|
||||
), "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
|
||||
|
||||
# override params if scoring_params is provided
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
|
||||
assert (
|
||||
fn_def.context.prompt_template is not None
|
||||
fn_def.params.prompt_template is not None
|
||||
), "LLM Judge prompt_template not found."
|
||||
assert (
|
||||
fn_def.context.judge_score_regex is not None
|
||||
), "LLM Judge judge_score_regex not found."
|
||||
fn_def.params.judge_score_regexes is not None
|
||||
), "LLM Judge judge_score_regexes not found."
|
||||
|
||||
input_query = input_row["input_query"]
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
judge_input_msg = fn_def.context.prompt_template.format(
|
||||
judge_input_msg = fn_def.params.prompt_template.format(
|
||||
input_query=input_query,
|
||||
expected_answer=expected_answer,
|
||||
generated_answer=generated_answer,
|
||||
)
|
||||
|
||||
judge_response = await self.inference_api.chat_completion(
|
||||
model=fn_def.context.judge_model,
|
||||
model=fn_def.params.judge_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
|||
],
|
||||
)
|
||||
content = judge_response.completion_message.content
|
||||
rating_regexs = fn_def.context.judge_score_regex
|
||||
rating_regexes = fn_def.params.judge_score_regexes
|
||||
|
||||
judge_rating = None
|
||||
for regex in rating_regexs:
|
||||
for regex in rating_regexes:
|
||||
match = re.search(regex, content)
|
||||
if match:
|
||||
judge_rating = int(match.group(1))
|
||||
|
|
|
@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn):
|
|||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "subset_of",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue