Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
register task required
Commit 94a56cc3f3 (parent 7ca479f400)
4 changed files with 30 additions and 22 deletions
@@ -70,7 +70,6 @@ class Eval(Protocol):
     async def run_eval(
         self,
         task_id: str,
-        task_def: EvalTaskDef,
         task_config: EvalTaskConfig,
     ) -> Job: ...
 
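With this change Eval.run_eval no longer accepts an inline EvalTaskDef: the task must be registered up front and is then referenced by task_id alone. A minimal sketch of the new calling convention, mirroring the updated tests below (the eval_tasks_impl / eval_impl handles, dataset id, scoring functions, and candidate are illustrative placeholders):

    from llama_stack.apis.eval.eval import AppEvalTaskConfig, EvalTaskDefWithProvider

    async def run_registered_eval(eval_tasks_impl, eval_impl, candidate):
        # 1) Register the task definition first (the new requirement).
        task_def = EvalTaskDefWithProvider(
            identifier="my_app_eval",                      # placeholder task id
            dataset_id="my_eval_dataset",                  # assumed to be registered already
            scoring_functions=["meta-reference::equality"],
            provider_id="meta-reference",
        )
        await eval_tasks_impl.register_eval_task(task_def)

        # 2) Run by id only; the provider resolves the definition internally.
        return await eval_impl.run_eval(
            task_id="my_app_eval",
            task_config=AppEvalTaskConfig(eval_candidate=candidate),
        )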
@@ -267,12 +267,10 @@ class EvalRouter(Eval):
     async def run_eval(
         self,
         task_id: str,
-        task_def: EvalTaskDef,
         task_config: AppEvalTaskConfig,
     ) -> Job:
         return await self.routing_table.get_provider_impl(task_id).run_eval(
             task_id=task_id,
-            task_def=task_def,
             task_config=task_config,
         )
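Since the router no longer forwards a task_def, whatever provider routing_table.get_provider_impl(task_id) resolves to must already have that task registered. A toy sketch of this dispatch shape (illustrative only, not the actual routing-table code):

    class ToyEvalRouter:
        """Dict-based dispatch keyed by task_id, for illustration."""

        def __init__(self, providers_by_task_id):
            self._providers = providers_by_task_id

        async def run_eval(self, task_id, task_config):
            provider = self._providers[task_id]  # unknown task_id -> KeyError, nothing to route
            return await provider.run_eval(task_id=task_id, task_config=task_config)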
@@ -51,22 +51,20 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
 
+        self.eval_tasks = {}
+
     async def initialize(self) -> None: ...
 
     async def shutdown(self) -> None: ...
 
+    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+        self.eval_tasks[task_def.identifier] = task_def
+
     async def list_eval_tasks(self) -> List[EvalTaskDef]:
         # NOTE: In order to be routed to this provider, the eval task def must have
         # a EvalTaskDef with identifier defined as DEFAULT_EVAL_TASK_IDENTIFIER
         # for app eval where eval task benchmark_id is not pre-registered
-        eval_tasks = [
-            EvalTaskDef(
-                identifier=DEFAULT_EVAL_TASK_IDENTIFIER,
-                dataset_id="",
-                scoring_functions=[],
-            )
-        ]
-        return eval_tasks
+        return list(self.eval_tasks.values())
 
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
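The meta-reference provider now keeps a plain in-memory registry: register_eval_task stores the definition under task_def.identifier, and list_eval_tasks returns only what has been registered, so it starts out empty (the updated test below asserts exactly that). A sketch of the resulting semantics, assuming an already-constructed impl and task_def:

    await impl.register_eval_task(task_def)   # stored under task_def.identifier
    await impl.register_eval_task(task_def)   # same identifier -> silently overwritten
    tasks = await impl.list_eval_tasks()      # [task_def]; empty before any registration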
@@ -94,9 +92,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
     async def run_eval(
         self,
         task_id: str,
-        task_def: EvalTaskDef,
         task_config: EvalTaskConfig,
     ) -> Job:
+        task_def = self.eval_tasks[task_id]
         dataset_id = task_def.dataset_id
         candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
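Because run_eval now looks the definition up with self.eval_tasks[task_id], calling it with an unregistered task_id fails with a bare KeyError. A hypothetical defensive variant (not part of this commit) would surface the requirement more clearly:

        task_def = self.eval_tasks.get(task_id)
        if task_def is None:
            raise ValueError(f"Eval task '{task_id}' is not registered; call register_eval_task first")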
@@ -9,7 +9,11 @@ import pytest
 
 from llama_models.llama3.api import SamplingParams
 
-from llama_stack.apis.eval.eval import AppEvalTaskConfig, EvalTaskDef, ModelCandidate
+from llama_stack.apis.eval.eval import (
+    AppEvalTaskConfig,
+    EvalTaskDefWithProvider,
+    ModelCandidate,
+)
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
 
@@ -28,7 +32,7 @@ class Testeval:
         _, eval_tasks_impl, _, _, _, _ = eval_stack
         response = await eval_tasks_impl.list_eval_tasks()
         assert isinstance(response, list)
-        assert len(response) >= 1
+        assert len(response) == 0
 
     @pytest.mark.asyncio
     async def test_eval_evaluate_rows(self, eval_stack):
@@ -48,9 +52,17 @@ class Testeval:
             "meta-reference::llm_as_judge_8b_correctness",
             "meta-reference::equality",
         ]
+        task_id = "meta-reference::app_eval"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
 
         response = await eval_impl.evaluate_rows(
-            task_id="meta-reference::app_eval",
+            task_id=task_id,
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
             task_config=AppEvalTaskConfig(
@@ -76,15 +88,16 @@ class Testeval:
             "meta-reference::subset_of",
         ]
 
-        task_id = "meta-reference::app_eval"
+        task_id = "meta-reference::app_eval-2"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.run_eval(
             task_id=task_id,
-            task_def=EvalTaskDef(
-                # NOTE: this is needed to make the router work for all app evals
-                identifier=task_id,
-                dataset_id="test_dataset_for_eval",
-                scoring_functions=scoring_functions,
-            ),
             task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",