mirror of https://github.com/meta-llama/llama-stack.git

commit 63b5eb929f
parent b95cb5308f

    migrate evals to resource

5 changed files with 57 additions and 40 deletions
@@ -12,8 +12,8 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.datasets import Dataset
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFnDef
 from llama_stack.apis.shields import Shield
@@ -67,9 +67,9 @@ class ScoringFunctionsProtocolPrivate(Protocol):
 
 
 class EvalTasksProtocolPrivate(Protocol):
-    async def list_eval_tasks(self) -> List[EvalTaskDef]: ...
+    async def list_eval_tasks(self) -> List[EvalTask]: ...
 
-    async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ...
+    async def register_eval_task(self, eval_task: EvalTask) -> None: ...
 
 
 @json_schema_type
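
The EvalTask model itself is presumably defined in one of the changed files not shown on this page (the llama_stack.apis.eval_tasks module it is imported from). As a rough sketch only, based on the fields the updated call sites below pass (eval_task_id, dataset_id, scoring_functions) plus the identifier and provider_id that the old EvalTaskDefWithProvider carried, the new resource likely looks something like this; the exact base class and field set in the real module may differ:

from typing import Any, Dict, List

from pydantic import BaseModel, Field


class EvalTask(BaseModel):
    # Field names mirrored from the call sites in this diff; metadata is a guess.
    identifier: str
    dataset_id: str
    scoring_functions: List[str]
    provider_id: str = ""
    metadata: Dict[str, Any] = Field(default_factory=dict)
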
@@ -11,7 +11,7 @@ from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
@@ -53,10 +53,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
     async def shutdown(self) -> None: ...
 
-    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+    async def register_eval_task(self, task_def: EvalTask) -> None:
         self.eval_tasks[task_def.identifier] = task_def
 
-    async def list_eval_tasks(self) -> List[EvalTaskDef]:
+    async def list_eval_tasks(self) -> List[EvalTask]:
         return list(self.eval_tasks.values())
 
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
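
For reference, the provider-side contract after this change boils down to storing EvalTask objects keyed by identifier and handing them back from list_eval_tasks. A minimal standalone sketch is shown below; it is not the actual MetaReferenceEvalImpl, whose inference and scoring wiring is omitted here:

from typing import Dict, List

from llama_stack.apis.eval_tasks import EvalTask  # import path as used in this diff


class InMemoryEvalTaskRegistry:
    """Bare-bones stand-in for the registry behavior the diff above implements."""

    def __init__(self) -> None:
        self.eval_tasks: Dict[str, EvalTask] = {}

    async def register_eval_task(self, task_def: EvalTask) -> None:
        # Store the resource under its identifier, exactly as MetaReferenceEvalImpl does.
        self.eval_tasks[task_def.identifier] = task_def

    async def list_eval_tasks(self) -> List[EvalTask]:
        return list(self.eval_tasks.values())
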
@@ -16,7 +16,6 @@ from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
     BenchmarkEvalTaskConfig,
-    EvalTaskDefWithProvider,
     ModelCandidate,
 )
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
@@ -70,13 +69,11 @@ class Testeval:
             "meta-reference::equality",
         ]
         task_id = "meta-reference::app_eval"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.evaluate_rows(
             task_id=task_id,
             input_rows=rows.rows,
@@ -125,13 +122,11 @@ class Testeval:
         ]
 
         task_id = "meta-reference::app_eval-2"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.run_eval(
             task_id=task_id,
             task_config=AppEvalTaskConfig(
@@ -189,15 +184,12 @@ class Testeval:
         await datasets_impl.register_dataset(mmlu)
 
         # register eval task
-        meta_reference_mmlu = EvalTaskDefWithProvider(
-            identifier="meta-reference-mmlu",
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id="meta-reference-mmlu",
             dataset_id="mmlu",
             scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
-            provider_id="",
         )
 
-        await eval_tasks_impl.register_eval_task(meta_reference_mmlu)
-
         # list benchmarks
         response = await eval_tasks_impl.list_eval_tasks()
         assert len(response) > 0
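
The net effect on callers: the tests no longer build an EvalTaskDefWithProvider object and hand it to register_eval_task; they call the eval_tasks API with keyword arguments directly. The signature implied by the updated call sites is roughly the following sketch; whether provider_id remains as an optional parameter is an assumption, since the new calls omit it:

from typing import List, Optional, Protocol

from llama_stack.apis.eval_tasks import EvalTask


class EvalTasksClientLike(Protocol):
    async def register_eval_task(
        self,
        eval_task_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        provider_id: Optional[str] = None,  # assumed optional; the updated tests no longer pass it
    ) -> None: ...

    async def list_eval_tasks(self) -> List[EvalTask]: ...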