From 41487e6ed143a3acb72fe331da41df4ad5cdb2cb Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 11 Dec 2024 10:47:37 -0800
Subject: [PATCH] refactor scoring/eval pytests (#607)

# What does this PR do?

- remove model registration & parameterize model in scoring/eval pytests

## Test Plan

```
pytest -v -s -m meta_reference_eval_together_inference eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio eval/test_eval.py
```

```
pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py
```

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
---
 llama_stack/providers/tests/eval/conftest.py  |  7 ++++
 llama_stack/providers/tests/eval/fixtures.py  | 11 +++++--
 llama_stack/providers/tests/eval/test_eval.py | 32 ++++++-------------
 .../providers/tests/scoring/conftest.py       | 15 +++++++++
 .../providers/tests/scoring/fixtures.py       | 12 +++++--
 .../providers/tests/scoring/test_scoring.py   | 20 +++---------
 6 files changed, 54 insertions(+), 43 deletions(-)

diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py
index b310439ce..1bb49d41f 100644
--- a/llama_stack/providers/tests/eval/conftest.py
+++ b/llama_stack/providers/tests/eval/conftest.py
@@ -80,6 +80,13 @@ def pytest_addoption(parser):
         help="Specify the inference model to use for testing",
     )
 
+    parser.addoption(
+        "--judge-model",
+        action="store",
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Specify the judge model to use for testing",
+    )
+
 
 def pytest_generate_tests(metafunc):
     if "eval_stack" in metafunc.fixturenames:
diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py
index 50dc9c16e..eba7c48a6 100644
--- a/llama_stack/providers/tests/eval/fixtures.py
+++ b/llama_stack/providers/tests/eval/fixtures.py
@@ -7,7 +7,7 @@
 import pytest
 import pytest_asyncio
 
-from llama_stack.distribution.datatypes import Api, Provider
+from llama_stack.distribution.datatypes import Api, ModelInput, Provider
 
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -35,7 +35,7 @@ EVAL_FIXTURES = ["meta_reference", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def eval_stack(request):
+async def eval_stack(request, inference_model, judge_model):
     fixture_dict = request.param
 
     providers = {}
@@ -66,6 +66,13 @@ async def eval_stack(request):
         ],
         providers,
         provider_data,
+        models=[
+            ModelInput(model_id=model)
+            for model in [
+                inference_model,
+                judge_model,
+            ]
+        ],
     )
 
     return test_stack.impls
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 168745550..38da74128 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -38,7 +38,7 @@ class Testeval:
         assert isinstance(response, list)
 
     @pytest.mark.asyncio
-    async def test_eval_evaluate_rows(self, eval_stack):
+    async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
         eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
@@ -46,11 +46,7 @@ class Testeval:
             eval_stack[Api.datasets],
             eval_stack[Api.models],
         )
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
+
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
@@ -77,12 +73,12 @@ class Testeval:
             scoring_functions=scoring_functions,
             task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
-                    model="Llama3.2-3B-Instruct",
+                    model=inference_model,
                     sampling_params=SamplingParams(),
                 ),
                 scoring_params={
                     "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
-                        judge_model="Llama3.1-8B-Instruct",
+                        judge_model=judge_model,
                         prompt_template=JUDGE_PROMPT,
                         judge_score_regexes=[
                             r"Total rating: (\d+)",
@@ -97,18 +93,14 @@ class Testeval:
         assert "basic::equality" in response.scores
 
     @pytest.mark.asyncio
-    async def test_eval_run_eval(self, eval_stack):
+    async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
         eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasets],
             eval_stack[Api.models],
         )
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
+
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
@@ -127,7 +119,7 @@ class Testeval:
             task_id=task_id,
             task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
-                    model="Llama3.2-3B-Instruct",
+                    model=inference_model,
                     sampling_params=SamplingParams(),
                 ),
             ),
@@ -142,18 +134,14 @@ class Testeval:
         assert "basic::subset_of" in eval_response.scores
 
     @pytest.mark.asyncio
-    async def test_eval_run_benchmark_eval(self, eval_stack):
+    async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
         eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.eval_tasks],
             eval_stack[Api.datasets],
             eval_stack[Api.models],
         )
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
+
         response = await datasets_impl.list_datasets()
         assert len(response) > 0
         if response[0].provider_id != "huggingface":
@@ -192,7 +180,7 @@ class Testeval:
             task_id=benchmark_id,
             task_config=BenchmarkEvalTaskConfig(
                 eval_candidate=ModelCandidate(
-                    model="Llama3.2-3B-Instruct",
+                    model=inference_model,
                     sampling_params=SamplingParams(),
                 ),
                 num_examples=3,
diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py
index 327acab84..dc4979dd7 100644
--- a/llama_stack/providers/tests/scoring/conftest.py
+++ b/llama_stack/providers/tests/scoring/conftest.py
@@ -47,6 +47,7 @@ def pytest_configure(config):
     for fixture_name in [
         "basic_scoring_together_inference",
         "braintrust_scoring_together_inference",
+        "llm_as_judge_scoring_together_inference",
     ]:
         config.addinivalue_line(
             "markers",
@@ -61,9 +62,23 @@ def pytest_addoption(parser):
         default="meta-llama/Llama-3.2-3B-Instruct",
         help="Specify the inference model to use for testing",
     )
+    parser.addoption(
+        "--judge-model",
+        action="store",
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Specify the judge model to use for testing",
+    )
 
 
 def pytest_generate_tests(metafunc):
+    judge_model = metafunc.config.getoption("--judge-model")
+    if "judge_model" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "judge_model",
+            [pytest.param(judge_model, id="")],
+            indirect=True,
+        )
+
     if "scoring_stack" in metafunc.fixturenames:
         available_fixtures = {
             "scoring": SCORING_FIXTURES,
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index a9f088e07..2cf32b1e2 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -21,6 +21,13 @@ def scoring_remote() -> ProviderFixture:
     return remote_stack_fixture()
 
 
+@pytest.fixture(scope="session")
+def judge_model(request):
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--judge-model", None)
+
+
 @pytest.fixture(scope="session")
 def scoring_basic() -> ProviderFixture:
     return ProviderFixture(
@@ -66,7 +73,7 @@ SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def scoring_stack(request, inference_model):
+async def scoring_stack(request, inference_model, judge_model):
     fixture_dict = request.param
 
     providers = {}
@@ -85,8 +92,7 @@ async def scoring_stack(request, inference_model):
             ModelInput(model_id=model)
             for model in [
                 inference_model,
-                "Llama3.1-405B-Instruct",
-                "Llama3.1-8B-Instruct",
+                judge_model,
             ]
         ],
     )
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 846d30cbb..dce069df0 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -64,12 +64,6 @@ class TestScoring:
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
 
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
-
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset",
@@ -103,7 +97,7 @@ class TestScoring:
 
     @pytest.mark.asyncio
     async def test_scoring_score_with_params_llm_as_judge(
-        self, scoring_stack, sample_judge_prompt_template
+        self, scoring_stack, sample_judge_prompt_template, judge_model
     ):
         (
             scoring_impl,
@@ -122,12 +116,6 @@ class TestScoring:
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
 
-        for model_id in ["Llama3.1-405B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
-
         scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
         provider_id = scoring_fns_list[0].provider_id
         if provider_id == "braintrust" or provider_id == "basic":
@@ -142,7 +130,7 @@ class TestScoring:
 
         scoring_functions = {
             "llm-as-judge::base": LLMAsJudgeScoringFnParams(
-                judge_model="Llama3.1-405B-Instruct",
+                judge_model=judge_model,
                 prompt_template=sample_judge_prompt_template,
                 judge_score_regexes=[r"Score: (\d+)"],
                 aggregation_functions=[AggregationFunctionType.categorical_count],
@@ -170,7 +158,7 @@ class TestScoring:
 
     @pytest.mark.asyncio
     async def test_scoring_score_with_aggregation_functions(
-        self, scoring_stack, sample_judge_prompt_template
+        self, scoring_stack, sample_judge_prompt_template, judge_model
     ):
         (
             scoring_impl,
@@ -204,7 +192,7 @@ class TestScoring:
             if x.provider_id == "llm-as-judge":
                 aggr_fns = [AggregationFunctionType.categorical_count]
                 scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
-                    judge_model="Llama3.1-405B-Instruct",
+                    judge_model=judge_model,
                     prompt_template=sample_judge_prompt_template,
                     judge_score_regexes=[r"Score: (\d+)"],
                     aggregation_functions=aggr_fns,