fix eval task registration (#426)

* fix eval tasks

* fix eval tasks

* fix eval tests

Author: Xi Yan
Date: 2024-11-12 11:51:34 -05:00 (committed by GitHub)
Parent: 84c6fbbd93
Commit: ec4fcad5ca
4 changed files with 16 additions and 13 deletions

@@ -17,6 +17,8 @@ from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
 from llama_stack.apis.scoring_functions import * # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.eval import Eval
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.safety import Safety
@@ -36,6 +38,7 @@ RoutableObject = Union[
     MemoryBank,
     Dataset,
     ScoringFn,
+    EvalTask,
 ]
@@ -46,6 +49,7 @@ RoutableObjectWithProvider = Annotated[
         MemoryBank,
         Dataset,
         ScoringFn,
+        EvalTask,
     ],
     Field(discriminator="type"),
 ]
@@ -56,6 +60,7 @@ RoutedProtocol = Union[
     Memory,
     DatasetIO,
     Scoring,
+    Eval,
 ]

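The hunks above are what actually enable eval task registration: EvalTask becomes a routable object and Eval a routed protocol, so a registered task can be dispatched to the provider that implements the Eval API. The toy sketch below only illustrates that object-type-to-API pairing; the dictionary and helper are hypothetical, not llama_stack code.

# Illustrative toy only (hypothetical names, not llama_stack code): each
# registrable object type is paired with the provider API that serves it,
# mirroring the RoutableObject/RoutedProtocol unions above.
ROUTABLE_TO_API = {
    "memory_bank": "memory",
    "dataset": "datasetio",
    "scoring_function": "scoring",
    "eval_task": "eval",  # the pairing this commit adds
}


def api_for(object_type: str) -> str:
    # Look up the API a registered object of this type is routed to.
    return ROUTABLE_TO_API[object_type]


assert api_for("eval_task") == "eval"
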
@@ -17,8 +17,8 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "eval": "meta_reference",
-            "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "scoring": "basic",
+            "datasetio": "localfs",
             "inference": "fireworks",
         },
         id="meta_reference_eval_fireworks_inference",
@@ -27,8 +27,8 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "eval": "meta_reference",
-            "scoring": "meta_reference",
-            "datasetio": "meta_reference",
+            "scoring": "basic",
+            "datasetio": "localfs",
             "inference": "together",
         },
         id="meta_reference_eval_together_inference",
@@ -37,7 +37,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "eval": "meta_reference",
-            "scoring": "meta_reference",
+            "scoring": "basic",
             "datasetio": "huggingface",
             "inference": "together",
         },

@@ -24,7 +24,7 @@ def eval_meta_reference() -> ProviderFixture:
         providers=[
             Provider(
                 provider_id="meta-reference",
-                provider_type="meta-reference",
+                provider_type="inline::meta-reference",
                 config={},
             )
         ],

@@ -63,8 +63,7 @@ class Testeval:
         assert len(rows.rows) == 3
         scoring_functions = [
-            "meta-reference::llm_as_judge_base",
-            "meta-reference::equality",
+            "basic::equality",
         ]
         task_id = "meta-reference::app_eval"
         await eval_tasks_impl.register_eval_task(
@@ -95,8 +94,7 @@ class Testeval:
             ),
         )
         assert len(response.generations) == 3
-        assert "meta-reference::equality" in response.scores
-        assert "meta-reference::llm_as_judge_base" in response.scores
+        assert "basic::equality" in response.scores

     @pytest.mark.asyncio
     async def test_eval_run_eval(self, eval_stack):
@@ -116,7 +114,7 @@ class Testeval:
         )
         scoring_functions = [
-            "meta-reference::subset_of",
+            "basic::subset_of",
         ]
         task_id = "meta-reference::app_eval-2"
@@ -141,7 +139,7 @@ class Testeval:
         assert eval_response is not None
         assert len(eval_response.generations) == 5
-        assert "meta-reference::subset_of" in eval_response.scores
+        assert "basic::subset_of" in eval_response.scores

     @pytest.mark.asyncio
     async def test_eval_run_benchmark_eval(self, eval_stack):
@@ -182,7 +180,7 @@ class Testeval:
         await eval_tasks_impl.register_eval_task(
             eval_task_id="meta-reference-mmlu",
             dataset_id="mmlu",
-            scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
+            scoring_functions=["basic::regex_parser_multiple_choice_answer"],
         )
         # list benchmarks
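
The registration path the tests exercise keeps its shape; the diff only swaps the provider and scoring-function identifiers to the new "basic", "localfs", and "inline::meta-reference" names. Below is a minimal sketch of the call as it appears in the benchmark test above; the eval_tasks_impl handle is assumed to come from the eval_stack fixture, and wiring up a stack is not shown.

# Sketch only: the registration call exercised by the updated tests.
# `eval_tasks_impl` is assumed to be the eval-tasks API handle yielded by
# the eval_stack test fixture; obtaining one is out of scope here.
async def register_mmlu_benchmark(eval_tasks_impl) -> None:
    await eval_tasks_impl.register_eval_task(
        eval_task_id="meta-reference-mmlu",
        dataset_id="mmlu",
        # scoring functions are now referenced by their "basic::" identifiers
        scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    )

Driving it from a script would just be asyncio.run(register_mmlu_benchmark(impl)) once a handle is available, mirroring how the async tests above invoke it.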