From 00658e02f8ba5c3b7bf4042b760619413281afa8 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 11 Dec 2024 10:36:39 -0800
Subject: [PATCH] fix eval tests model registration

---
 llama_stack/providers/tests/eval/conftest.py  |  7 ++++
 llama_stack/providers/tests/eval/fixtures.py  | 11 +++++--
 llama_stack/providers/tests/eval/test_eval.py | 32 ++++++-------------
 3 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py
index b310439ce..1bb49d41f 100644
--- a/llama_stack/providers/tests/eval/conftest.py
+++ b/llama_stack/providers/tests/eval/conftest.py
@@ -80,6 +80,13 @@ def pytest_addoption(parser):
         help="Specify the inference model to use for testing",
     )
 
+    parser.addoption(
+        "--judge-model",
+        action="store",
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Specify the judge model to use for testing",
+    )
+
 
 def pytest_generate_tests(metafunc):
     if "eval_stack" in metafunc.fixturenames:
diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py
index 50dc9c16e..eba7c48a6 100644
--- a/llama_stack/providers/tests/eval/fixtures.py
+++ b/llama_stack/providers/tests/eval/fixtures.py
@@ -7,7 +7,7 @@
 import pytest
 import pytest_asyncio
 
-from llama_stack.distribution.datatypes import Api, Provider
+from llama_stack.distribution.datatypes import Api, ModelInput, Provider
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -35,7 +35,7 @@ EVAL_FIXTURES = ["meta_reference", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def eval_stack(request):
+async def eval_stack(request, inference_model, judge_model):
     fixture_dict = request.param
 
     providers = {}
@@ -66,6 +66,13 @@ async def eval_stack(request):
         ],
         providers,
         provider_data,
+        models=[
+            ModelInput(model_id=model)
+            for model in [
+                inference_model,
+                judge_model,
+            ]
+        ],
     )
 
     return test_stack.impls
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 168745550..38da74128 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -38,7 +38,7 @@ class Testeval:
         assert isinstance(response, list)
 
     @pytest.mark.asyncio
-    async def test_eval_evaluate_rows(self, eval_stack):
+    async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
         eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
@@ -46,11 +46,7 @@ class Testeval:
             eval_stack[Api.datasets],
             eval_stack[Api.models],
         )
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
+
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
@@ -77,12 +73,12 @@ class Testeval:
             scoring_functions=scoring_functions,
             task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
-                    model="Llama3.2-3B-Instruct",
+                    model=inference_model,
                     sampling_params=SamplingParams(),
                 ),
                 scoring_params={
                     "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
-                        judge_model="Llama3.1-8B-Instruct",
+                        judge_model=judge_model,
                         prompt_template=JUDGE_PROMPT,
                         judge_score_regexes=[
                             r"Total rating: (\d+)",
@@ -97,18 +93,14 @@ class Testeval:
         assert "basic::equality" in response.scores
 
     @pytest.mark.asyncio
-    async def test_eval_run_eval(self, eval_stack):
+    async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
         eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasets],
             eval_stack[Api.models],
         )
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
+
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
@@ -127,7 +119,7 @@ class Testeval:
             task_id=task_id,
             task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
-                    model="Llama3.2-3B-Instruct",
+                    model=inference_model,
                     sampling_params=SamplingParams(),
                 ),
             ),
@@ -142,18 +134,14 @@ class Testeval:
         assert "basic::subset_of" in eval_response.scores
 
     @pytest.mark.asyncio
-    async def test_eval_run_benchmark_eval(self, eval_stack):
+    async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
         eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasets],
             eval_stack[Api.models],
         )
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
+
         response = await datasets_impl.list_datasets()
         assert len(response) > 0
         if response[0].provider_id != "huggingface":
@@ -192,7 +180,7 @@ class Testeval:
             task_id=benchmark_id,
             task_config=BenchmarkEvalTaskConfig(
                 eval_candidate=ModelCandidate(
-                    model="Llama3.2-3B-Instruct",
+                    model=inference_model,
                     sampling_params=SamplingParams(),
                 ),
                 num_examples=3,
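
Note: the tests above request a judge_model fixture by name, but this patch only adds the --judge-model pytest option; the code that turns that option into a fixture is not shown here. A minimal sketch of one plausible wiring, assuming a plain option-backed session fixture in the shared conftest (hypothetical; the repository may instead parametrize judge_model inside pytest_generate_tests):

    import pytest


    @pytest.fixture(scope="session")
    def judge_model(request):
        # Hypothetical sketch: expose the --judge-model CLI option (added in
        # the conftest.py hunk above) as a fixture that eval_stack and the
        # test methods can request by name.
        return request.config.getoption("--judge-model")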