api refactor

Xi Yan, 2024-11-07 13:54:26 -08:00
commit 51c20f9c29 (parent 97dcd5704c)
8 changed files with 64 additions and 59 deletions


@@ -38,14 +38,16 @@ EvalCandidate = Annotated[
 @json_schema_type
 class BenchmarkEvalTaskConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate


 @json_schema_type
 class AppEvalTaskConfig(BaseModel):
+    type: Literal["app"] = "app"
     eval_candidate: EvalCandidate
     scoring_params: Dict[str, ScoringFnParams] = Field(
-        description="Map between scoring function id and parameters",
+        description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
     # we could optinally add any specific dataset config here
@@ -64,18 +66,18 @@ class EvaluateResponse(BaseModel):
 class Eval(Protocol):
-    @webmethod(route="/eval/run_benchmark_eval", method="POST")
-    async def run_benchmark_eval(
+    @webmethod(route="/eval/run_benchmark", method="POST")
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job: ...

     @webmethod(route="/eval/run_eval", method="POST")
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job: ...

     @webmethod(route="/eval/evaluate_rows", method="POST")
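Taken together, the two hunks above split the old entry point in two: benchmark evals go through the new run_benchmark route, while app evals call run_eval with a required AppEvalTaskConfig. A minimal caller-side sketch of the new run_eval shape, assuming the relevant types are already imported and that eval_impl is a client or impl object satisfying this Eval protocol (identifiers and candidate values borrowed from the test changes further down):

# Sketch only: "eval_impl" and the surrounding import context are assumptions, not part of this diff.
job = await eval_impl.run_eval(
    task=EvalTaskDef(
        identifier="meta-reference::app_eval",
        dataset_id="test_dataset_for_eval",
        scoring_functions=["meta-reference::equality"],
    ),
    task_config=AppEvalTaskConfig(
        eval_candidate=ModelCandidate(
            model="Llama3.2-3B-Instruct",
            sampling_params=SamplingParams(),
        ),
        # scoring_params is optional; it defaults to an empty dict.
    ),
)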


@@ -48,8 +48,7 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
@@ -57,6 +56,5 @@ class Scoring(Protocol):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
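The two former arguments are folded into one here: the dict's keys pick which scoring functions to run, and its values carry optional per-function parameters, with None meaning "use the function's defaults". A hedged sketch of a call against the new signature, assuming scoring_impl implements this protocol and the parameter types are in scope (identifiers and judge settings borrowed from the test changes further down):

# Sketch only: "scoring_impl", "rows", and the import context are assumptions, not part of this diff.
response = await scoring_impl.score(
    input_rows=rows,
    scoring_functions={
        "meta-reference::equality": None,  # run with default params
        "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
            judge_model="Llama3.1-405B-Instruct",
            prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
        ),
    },
)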


@@ -76,7 +76,7 @@ class ScoringFnDef(BaseModel):
         description="The return type of the deterministic function",
     )
     params: Optional[ScoringFnParams] = Field(  # type: ignore
-        description="The parameters for the scoring function for benchmark eval, we could override this for app eval",
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )
     # We can optionally add information here to support packaging of code, etc.


@@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score_batch(
                 dataset_id=dataset_id,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
@@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score(
                 input_rows=input_rows,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
@@ -272,24 +268,24 @@ class EvalRouter(Eval):
     async def shutdown(self) -> None:
         pass

-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         pass

     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
         # NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
         return await self.routing_table.get_provider_impl(
             DEFAULT_EVAL_TASK_IDENTIFIER
         ).run_eval(
-            eval_task_def=eval_task_def,
-            eval_task_config=eval_task_config,
+            task=task,
+            task_config=task_config,
         )

     @webmethod(route="/eval/evaluate_rows", method="POST")


@@ -7,11 +7,17 @@ from enum import Enum
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkEvalTaskConfig
+from .....apis.eval.eval import (
+    AppEvalTaskConfig,
+    BenchmarkEvalTaskConfig,
+    Eval,
+    EvalTaskConfig,
+    EvaluateResponse,
+    JobStatus,
+)
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
@@ -88,21 +94,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )

-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         raise NotImplementedError("Benchmark eval is not implemented yet")

     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
-        dataset_id = eval_task_def.dataset_id
-        candidate = eval_task_config.eval_candidate
-        scoring_functions = eval_task_def.scoring_functions
+        dataset_id = task.dataset_id
+        candidate = task_config.eval_candidate
+        scoring_functions = task.scoring_functions

         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
@@ -112,7 +118,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         res = await self.evaluate_rows(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=eval_task_config,
+            eval_task_config=task_config,
         )

         # TODO: currently needs to wait for generation before returning
@@ -179,8 +185,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             for input_r, generated_r in zip(input_rows, generations)
         ]

+        if (
+            eval_task_config.type == "app"
+            and eval_task_config.scoring_params is not None
+        ):
+            scoring_functions_dict = {
+                scoring_fn_id: eval_task_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in scoring_functions
+            }
+        else:
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
+
         score_response = await self.scoring_api.score(
-            input_rows=score_input_rows, scoring_functions=scoring_functions
+            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )

         return EvaluateResponse(generations=generations, scores=score_response.results)
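The block added above is the glue between the two refactored APIs: it expands the task's flat list of scoring function ids into the dict the new Scoring API expects, attaching any overrides found in an app config's scoring_params and falling back to None otherwise. A self-contained, plain-Python sketch of just that mapping, with hypothetical ids and a string standing in for a ScoringFnParams object:

# Hypothetical ids and params, purely to illustrate the mapping performed above.
scoring_functions = ["scorer::judge", "scorer::equality"]
scoring_params = {"scorer::judge": "judge-params"}  # stand-in for Dict[str, ScoringFnParams]

scoring_functions_dict = {
    fn_id: scoring_params.get(fn_id, None) for fn_id in scoring_functions
}
# -> {"scorer::judge": "judge-params", "scorer::equality": None}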


@@ -96,8 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -108,7 +107,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         res = await self.score(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -122,17 +120,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = None
-            if scoring_params is not None:
-                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(
                 input_rows, scoring_fn_id, scoring_fn_params
             )


@@ -76,13 +76,13 @@ class Testeval:
         ]

         response = await eval_impl.run_eval(
-            eval_task_def=EvalTaskDef(
+            task=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
                 identifier="meta-reference::app_eval",
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),


@@ -44,10 +44,10 @@ class TestScoring:
         )
         assert len(rows.rows) == 3

-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-            "meta-reference::equality",
-        ]
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
@@ -83,7 +83,7 @@ class TestScoring:
         )
         assert len(rows.rows) == 3

-        params = {
+        scoring_functions = {
             "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
@@ -91,13 +91,9 @@ class TestScoring:
             )
         }

-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-        ]
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
@@ -108,7 +104,6 @@ class TestScoring:
         response = await scoring_impl.score_batch(
             dataset_id="test_dataset",
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions: