Mirror of https://github.com/meta-llama/llama-stack.git
api refactor

commit 51c20f9c29 (parent 97dcd5704c)
8 changed files with 64 additions and 59 deletions
@@ -38,14 +38,16 @@ EvalCandidate = Annotated[
 
 @json_schema_type
 class BenchmarkEvalTaskConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
 
 
 @json_schema_type
 class AppEvalTaskConfig(BaseModel):
+    type: Literal["app"] = "app"
     eval_candidate: EvalCandidate
     scoring_params: Dict[str, ScoringFnParams] = Field(
-        description="Map between scoring function id and parameters",
+        description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
     # we could optinally add any specific dataset config here

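Note: the added `type` literals act as discriminator tags for the eval task config variants. A minimal usage sketch, assuming an `EvalCandidate` (`candidate`) and a `ScoringFnParams` instance (`judge_params`) built elsewhere; those names are illustrative, not part of the diff:

# Hypothetical construction of the new config models.
config = AppEvalTaskConfig(
    eval_candidate=candidate,  # an EvalCandidate, assumed to exist
    scoring_params={
        "meta-reference::llm_as_judge_8b_correctness": judge_params,  # ScoringFnParams, assumed
    },
)
assert config.type == "app"  # the Literal default tags this config variant

benchmark_config = BenchmarkEvalTaskConfig(eval_candidate=candidate)
assert benchmark_config.type == "benchmark"
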
@@ -64,18 +66,18 @@ class EvaluateResponse(BaseModel):
 
 
 class Eval(Protocol):
-    @webmethod(route="/eval/run_benchmark_eval", method="POST")
-    async def run_benchmark_eval(
+    @webmethod(route="/eval/run_benchmark", method="POST")
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job: ...
 
     @webmethod(route="/eval/run_eval", method="POST")
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job: ...
 
     @webmethod(route="/eval/evaluate_rows", method="POST")

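Note: callers of the `Eval` protocol now use the shorter `/eval/run_benchmark` route and the renamed keyword arguments. A rough sketch from inside an async caller, assuming an object `eval_api` implementing the protocol and a `task_def`/`candidate` built elsewhere (none of these names are defined in this diff):

# Hypothetical calls reflecting the renamed signatures above.
job = await eval_api.run_benchmark(
    benchmark_id="my-benchmark",  # placeholder identifier
    benchmark_config=BenchmarkEvalTaskConfig(eval_candidate=candidate),
)

job = await eval_api.run_eval(
    task=task_def,  # an EvalTaskDef, assumed to exist
    task_config=AppEvalTaskConfig(eval_candidate=candidate),
)
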
@@ -48,8 +48,7 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...
 
@@ -57,6 +56,5 @@ class Scoring(Protocol):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse: ...

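Note: the scoring API collapses the former `scoring_functions: List[str]` plus a separate `scoring_params` mapping into a single `Dict[str, ScoringFnParams]`, where a `None` value leaves that function on its defaults. A before/after sketch from inside an async caller, with function ids borrowed from the tests in this commit and a hypothetical `scoring_api` object and `judge_params` instance:

# Old shape (before this commit):
#   await scoring_api.score(input_rows=rows, scoring_functions=["meta-reference::equality"], scoring_params=params)
# New shape: parameters travel with the function id; None means defaults.
response = await scoring_api.score(
    input_rows=rows,
    scoring_functions={
        "meta-reference::equality": None,
        "meta-reference::llm_as_judge_8b_correctness": judge_params,  # ScoringFnParams, assumed
    },
)
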
@@ -76,7 +76,7 @@ class ScoringFnDef(BaseModel):
         description="The return type of the deterministic function",
     )
     params: Optional[ScoringFnParams] = Field( # type: ignore
-        description="The parameters for the scoring function for benchmark eval, we could override this for app eval",
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )
     # We can optionally add information here to support packaging of code, etc.

@@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score_batch(
                 dataset_id=dataset_id,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
 
@@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score(
                 input_rows=input_rows,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
 
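Note: the routers now fan each requested function out to its provider with a single-entry dict that carries that function's own parameters, instead of a one-element list plus a shared `scoring_params` mapping. Illustrative fan-out (comments only; ids and `judge_params` are hypothetical):

# Caller passes:
#   {"meta-reference::equality": None, "meta-reference::llm_as_judge_8b_correctness": judge_params}
# ScoringRouter.score forwards one provider call per id:
#   provider.score(input_rows=rows, scoring_functions={"meta-reference::equality": None})
#   provider.score(input_rows=rows, scoring_functions={"meta-reference::llm_as_judge_8b_correctness": judge_params})
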
@@ -272,24 +268,24 @@ class EvalRouter(Eval):
     async def shutdown(self) -> None:
         pass
 
-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         pass
 
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
         # NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
         return await self.routing_table.get_provider_impl(
             DEFAULT_EVAL_TASK_IDENTIFIER
         ).run_eval(
-            eval_task_def=eval_task_def,
-            eval_task_config=eval_task_config,
+            task=task,
+            task_config=task_config,
         )
 
     @webmethod(route="/eval/evaluate_rows", method="POST")

@@ -7,11 +7,17 @@ from enum import Enum
 from llama_models.llama3.api.datatypes import * # noqa: F403
 
 from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkEvalTaskConfig
+from .....apis.eval.eval import (
+    AppEvalTaskConfig,
+    BenchmarkEvalTaskConfig,
+    Eval,
+    EvalTaskConfig,
+    EvaluateResponse,
+    JobStatus,
+)
 from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring

@@ -88,21 +94,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
 
-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         raise NotImplementedError("Benchmark eval is not implemented yet")
 
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
-        dataset_id = eval_task_def.dataset_id
-        candidate = eval_task_config.eval_candidate
-        scoring_functions = eval_task_def.scoring_functions
+        dataset_id = task.dataset_id
+        candidate = task_config.eval_candidate
+        scoring_functions = task.scoring_functions
 
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(

@@ -112,7 +118,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         res = await self.evaluate_rows(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=eval_task_config,
+            eval_task_config=task_config,
         )
 
         # TODO: currently needs to wait for generation before returning

@@ -179,8 +185,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             for input_r, generated_r in zip(input_rows, generations)
         ]
 
+        if (
+            eval_task_config.type == "app"
+            and eval_task_config.scoring_params is not None
+        ):
+            scoring_functions_dict = {
+                scoring_fn_id: eval_task_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in scoring_functions
+            }
+        else:
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
+
         score_response = await self.scoring_api.score(
-            input_rows=score_input_rows, scoring_functions=scoring_functions
+            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )
 
         return EvaluateResponse(generations=generations, scores=score_response.results)

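Note: the new branch above simply adapts the eval task's list of scoring function ids into the dict shape the scoring API now expects. A small sketch of what the comprehension produces, with hypothetical ids and params:

# Mirrors the dict comprehension in the hunk above.
scoring_functions = ["fn_a", "fn_b"]             # hypothetical ids from an EvalTaskDef
app_scoring_params = {"fn_a": params_a}          # params_a: a hypothetical ScoringFnParams
scoring_functions_dict = {
    fn_id: app_scoring_params.get(fn_id, None) for fn_id in scoring_functions
}
# -> {"fn_a": params_a, "fn_b": None}; with a benchmark config every value is None
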
@@ -96,8 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)

@@ -108,7 +107,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         res = await self.score(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading

@@ -122,17 +120,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = None
-            if scoring_params is not None:
-                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(
                 input_rows, scoring_fn_id, scoring_fn_params
             )

@@ -76,13 +76,13 @@ class Testeval:
         ]
 
         response = await eval_impl.run_eval(
-            eval_task_def=EvalTaskDef(
+            task=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
                 identifier="meta-reference::app_eval",
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),

@@ -44,10 +44,10 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-            "meta-reference::equality",
-        ]
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,

@@ -83,7 +83,7 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        params = {
+        scoring_functions = {
             "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",

@@ -91,13 +91,9 @@ class TestScoring:
             )
         }
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-        ]
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:

@@ -108,7 +104,6 @@ class TestScoring:
         response = await scoring_impl.score_batch(
             dataset_id="test_dataset",
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions: