fix eval task registration (#426)

* fix eval tasks

* fix eval tasks

* fix eval tests
Xi Yan 2024-11-12 11:51:34 -05:00 committed by GitHub
parent 84c6fbbd93
commit ec4fcad5ca
4 changed files with 16 additions and 13 deletions

View file

@@ -17,6 +17,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.apis.scoring_functions import * # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.eval import Eval
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.safety import Safety
@@ -36,6 +38,7 @@ RoutableObject = Union[
     MemoryBank,
     Dataset,
     ScoringFn,
+    EvalTask,
 ]
@@ -46,6 +49,7 @@ RoutableObjectWithProvider = Annotated[
         MemoryBank,
         Dataset,
         ScoringFn,
+        EvalTask,
     ],
     Field(discriminator="type"),
 ]
@@ -56,6 +60,7 @@ RoutedProtocol = Union[
     Memory,
     DatasetIO,
     Scoring,
+    Eval,
 ]
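
The datatypes change above matters because the routable-object unions are discriminated on the `type` field: a resource class that is not listed in the union cannot be parsed or routed at all, which is why `EvalTask` (and the `Eval` protocol) must be added in each place. Below is a minimal, self-contained sketch of that pattern using toy pydantic v2 models; these are stand-ins, not the real `llama_stack` classes.

from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class ScoringFn(BaseModel):
    type: Literal["scoring_function"] = "scoring_function"
    identifier: str


class EvalTask(BaseModel):
    type: Literal["eval_task"] = "eval_task"
    identifier: str
    dataset_id: str
    scoring_functions: List[str]


# Same shape as the unions in the diff: if EvalTask were missing from the
# union, validating an eval-task payload would fail and it could never be routed.
RoutableObject = Annotated[
    Union[ScoringFn, EvalTask],
    Field(discriminator="type"),
]

obj = TypeAdapter(RoutableObject).validate_python(
    {
        "type": "eval_task",
        "identifier": "meta-reference-mmlu",
        "dataset_id": "mmlu",
        "scoring_functions": ["basic::regex_parser_multiple_choice_answer"],
    }
)
print(type(obj).__name__)  # EvalTask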

View file

@@ -17,8 +17,8 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "eval": "meta_reference",
-            "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "scoring": "basic",
+            "datasetio": "localfs",
             "inference": "fireworks",
         },
         id="meta_reference_eval_fireworks_inference",
@@ -27,8 +27,8 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "eval": "meta_reference",
-            "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "scoring": "basic",
+            "datasetio": "localfs",
             "inference": "together",
         },
         id="meta_reference_eval_together_inference",
@@ -37,7 +37,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "eval": "meta_reference",
-            "scoring": "meta_reference",
+            "scoring": "basic",
             "datasetio": "huggingface",
             "inference": "together",
         },
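
This conftest change only swaps the scoring and datasetio provider ids in the test matrix. As a rough sketch of how such a matrix drives pytest parametrization (the fixture below is a simplified stand-in, not the real llama_stack test harness):

import pytest

PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "eval": "meta_reference",
            "scoring": "basic",
            "datasetio": "localfs",
            "inference": "fireworks",
        },
        id="meta_reference_eval_fireworks_inference",
    ),
]


@pytest.fixture(params=PROVIDER_COMBINATIONS)
def provider_combo(request):
    # Each parametrized run receives one {api -> provider_id} mapping.
    return request.param


def test_scoring_provider_is_basic(provider_combo):
    assert provider_combo["scoring"] == "basic"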

View file

@@ -24,7 +24,7 @@ def eval_meta_reference() -> ProviderFixture:
         providers=[
             Provider(
                 provider_id="meta-reference",
-                provider_type="meta-reference",
+                provider_type="inline::meta-reference",
                 config={},
             )
         ],
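
Read in full, the fixture now looks roughly like this; this is a reconstruction, and the imports, decorator, and module paths are assumptions — only the Provider fields come from the diff.

import pytest

# Assumed import locations for Provider / ProviderFixture; adjust to the repo layout.
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.tests.fixtures import ProviderFixture


@pytest.fixture(scope="session")
def eval_meta_reference() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="meta-reference",
                # provider types are now namespaced ("inline::..."), matching
                # the identifiers used when providers are registered
                provider_type="inline::meta-reference",
                config={},
            )
        ],
    )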

View file

@@ -63,8 +63,7 @@ class Testeval:
         assert len(rows.rows) == 3
         scoring_functions = [
-            "meta-reference::llm_as_judge_base",
-            "meta-reference::equality",
+            "basic::equality",
         ]
         task_id = "meta-reference::app_eval"
         await eval_tasks_impl.register_eval_task(
@@ -95,8 +94,7 @@ class Testeval:
             ),
         )
         assert len(response.generations) == 3
-        assert "meta-reference::equality" in response.scores
-        assert "meta-reference::llm_as_judge_base" in response.scores
+        assert "basic::equality" in response.scores

     @pytest.mark.asyncio
     async def test_eval_run_eval(self, eval_stack):
@@ -116,7 +114,7 @@ class Testeval:
         )
         scoring_functions = [
-            "meta-reference::subset_of",
+            "basic::subset_of",
         ]
         task_id = "meta-reference::app_eval-2"
@@ -141,7 +139,7 @@ class Testeval:
         assert eval_response is not None
         assert len(eval_response.generations) == 5
-        assert "meta-reference::subset_of" in eval_response.scores
+        assert "basic::subset_of" in eval_response.scores

     @pytest.mark.asyncio
     async def test_eval_run_benchmark_eval(self, eval_stack):
@@ -182,7 +180,7 @@ class Testeval:
         await eval_tasks_impl.register_eval_task(
             eval_task_id="meta-reference-mmlu",
             dataset_id="mmlu",
-            scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
+            scoring_functions=["basic::regex_parser_multiple_choice_answer"],
         )
         # list benchmarks
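
Taken together, the tests now register eval tasks against the renamed `basic::` scoring functions. A hedged sketch of that registration step in isolation: `eval_tasks_impl` stands in for the eval_tasks API handle pulled from the test's eval_stack fixture, and `list_eval_tasks()` plus the `identifier` attribute are assumptions about the registry API.

async def register_mmlu_benchmark(eval_tasks_impl) -> None:
    # Register an eval task whose scoring function now lives under the
    # "basic" provider namespace instead of "meta-reference".
    await eval_tasks_impl.register_eval_task(
        eval_task_id="meta-reference-mmlu",
        dataset_id="mmlu",
        scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    )

    # List registered tasks and check ours is present (assumed API shape).
    eval_tasks = await eval_tasks_impl.list_eval_tasks()
    assert any(task.identifier == "meta-reference-mmlu" for task in eval_tasks)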