api refactor

Xi Yan 2024-11-07 13:54:26 -08:00
parent 97dcd5704c
commit 51c20f9c29
8 changed files with 64 additions and 59 deletions


@@ -7,11 +7,17 @@ from enum import Enum
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
 from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkEvalTaskConfig
+from .....apis.eval.eval import (
+    AppEvalTaskConfig,
+    BenchmarkEvalTaskConfig,
+    Eval,
+    EvalTaskConfig,
+    EvaluateResponse,
+    JobStatus,
+)
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
@@ -88,21 +94,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
 
-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         raise NotImplementedError("Benchmark eval is not implemented yet")
 
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
-        dataset_id = eval_task_def.dataset_id
-        candidate = eval_task_config.eval_candidate
-        scoring_functions = eval_task_def.scoring_functions
+        dataset_id = task.dataset_id
+        candidate = task_config.eval_candidate
+        scoring_functions = task.scoring_functions
 
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
@@ -112,7 +118,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         res = await self.evaluate_rows(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=eval_task_config,
+            eval_task_config=task_config,
         )
 
         # TODO: currently needs to wait for generation before returning
@@ -179,8 +185,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             for input_r, generated_r in zip(input_rows, generations)
         ]
 
+        if (
+            eval_task_config.type == "app"
+            and eval_task_config.scoring_params is not None
+        ):
+            scoring_functions_dict = {
+                scoring_fn_id: eval_task_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in scoring_functions
+            }
+        else:
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
+
         score_response = await self.scoring_api.score(
-            input_rows=score_input_rows, scoring_functions=scoring_functions
+            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
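
Note: the net effect of this hunk is that the eval provider now hands the scoring API a single dict keyed by scoring-function id, with per-function params filled in only when an app-type task config supplies them. A minimal standalone Python sketch of that mapping (the helper name and the placeholder ids are illustrative, not part of this commit):

from typing import Any, Dict, List, Optional


def build_scoring_functions_dict(
    scoring_functions: List[str],
    scoring_params: Optional[Dict[str, Any]],
) -> Dict[str, Optional[Any]]:
    # Mirrors the if/else added above: use per-function params when the app
    # task config supplies them, otherwise map every id to None (defaults).
    if scoring_params is not None:
        return {fn_id: scoring_params.get(fn_id, None) for fn_id in scoring_functions}
    return {fn_id: None for fn_id in scoring_functions}


# Example with a placeholder id:
print(build_scoring_functions_dict(["meta-reference::equality"], None))
# -> {'meta-reference::equality': None}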


@@ -96,8 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -108,7 +107,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         res = await self.score(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -122,17 +120,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = None
-            if scoring_params is not None:
-                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(
                 input_rows, scoring_fn_id, scoring_fn_params
             )
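
Note: with this change, `score` and `score_batch` take a single `scoring_functions: Optional[Dict[str, ScoringFnParams]]` argument instead of a `List[str]` plus a separate `scoring_params` dict; a `None` value for a given id means the scoring function runs with its defaults. A hedged caller sketch (the `score_with_params` wrapper is hypothetical, and `scoring_impl`/`rows` stand in for the provider impl and paginated rows used in the tests below):

from typing import Any


async def score_with_params(scoring_impl: Any, rows: Any) -> Any:
    # Hypothetical illustration of the new call shape; not code from this commit.
    return await scoring_impl.score(
        input_rows=rows.rows,
        scoring_functions={
            # None -> let the scoring function use its default params.
            "meta-reference::equality": None,
            # A ScoringFnParams instance (e.g. LLMAsJudgeScoringFnParams(...))
            # can be supplied here instead, keyed by scoring function id.
            "meta-reference::llm_as_judge_8b_correctness": None,
        },
    )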


@@ -76,13 +76,13 @@ class Testeval:
         ]
 
         response = await eval_impl.run_eval(
-            eval_task_def=EvalTaskDef(
+            task=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
                 identifier="meta-reference::app_eval",
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),


@@ -44,10 +44,10 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-            "meta-reference::equality",
-        ]
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
@@ -83,7 +83,7 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        params = {
+        scoring_functions = {
             "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
@@ -91,13 +91,9 @@ class TestScoring:
             )
         }
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-        ]
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
@@ -108,7 +104,6 @@ class TestScoring:
         response = await scoring_impl.score_batch(
             dataset_id="test_dataset",
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
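
Note: the unchanged assertions at the end of these tests still hold after `scoring_functions` switches from a list to a dict, because iterating a dict yields its keys and `len()` counts its entries. A quick standalone check with placeholder ids:

# Placeholder ids only; shows why the list -> dict switch leaves the
# trailing assertions untouched.
scoring_functions = {"fn_a": None, "fn_b": None}
results = {"fn_a": 1.0, "fn_b": 0.0}

assert len(results) == len(scoring_functions)
for x in scoring_functions:  # iterating a dict yields its keys
    assert x in results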