From 3b5a33d921f28b6c387e35fa593da417fdb841c3 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 11 Dec 2024 10:31:05 -0800
Subject: [PATCH] parameterize judge_model
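
Make the LLM-as-judge model configurable instead of hard-coding
Llama3.1-405B-Instruct / Llama3.1-8B-Instruct in the scoring tests: a new
--judge-model pytest option feeds a session-scoped judge_model fixture,
which scoring_stack registers alongside the inference model.

Test plan sketch (an illustrative invocation assembled from the option,
marker, and default model this patch adds; not taken from CI config):

    # illustrative only: marker, flag, and model ID come from this patch
    pytest -v -s llama_stack/providers/tests/scoring/test_scoring.py \
      -m llm_as_judge_scoring_together_inference \
      --judge-model meta-llama/Llama-3.1-8B-Instruct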
---
 .../providers/tests/scoring/conftest.py     | 15 +++++++++++++++
 .../providers/tests/scoring/fixtures.py     | 12 +++++++++---
 .../providers/tests/scoring/test_scoring.py | 20 ++++----------------
 3 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py
index 327acab84..dc4979dd7 100644
--- a/llama_stack/providers/tests/scoring/conftest.py
+++ b/llama_stack/providers/tests/scoring/conftest.py
@@ -47,6 +47,7 @@ def pytest_configure(config):
     for fixture_name in [
         "basic_scoring_together_inference",
         "braintrust_scoring_together_inference",
+        "llm_as_judge_scoring_together_inference",
     ]:
         config.addinivalue_line(
             "markers",
@@ -61,9 +62,23 @@ def pytest_addoption(parser):
         default="meta-llama/Llama-3.2-3B-Instruct",
         help="Specify the inference model to use for testing",
     )
+    parser.addoption(
+        "--judge-model",
+        action="store",
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Specify the judge model to use for testing",
+    )
 
 
 def pytest_generate_tests(metafunc):
+    judge_model = metafunc.config.getoption("--judge-model")
+    if "judge_model" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "judge_model",
+            [pytest.param(judge_model, id="")],
+            indirect=True,
+        )
+
     if "scoring_stack" in metafunc.fixturenames:
         available_fixtures = {
             "scoring": SCORING_FIXTURES,
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index a9f088e07..2cf32b1e2 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -21,6 +21,13 @@ def scoring_remote() -> ProviderFixture:
     return remote_stack_fixture()
 
 
+@pytest.fixture(scope="session")
+def judge_model(request):
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--judge-model", None)
+
+
 @pytest.fixture(scope="session")
 def scoring_basic() -> ProviderFixture:
     return ProviderFixture(
@@ -66,7 +73,7 @@ SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def scoring_stack(request, inference_model):
+async def scoring_stack(request, inference_model, judge_model):
     fixture_dict = request.param
 
     providers = {}
@@ -85,8 +92,7 @@ async def scoring_stack(request, inference_model):
             ModelInput(model_id=model)
             for model in [
                 inference_model,
-                "Llama3.1-405B-Instruct",
-                "Llama3.1-8B-Instruct",
+                judge_model,
             ]
         ],
     )
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 846d30cbb..dce069df0 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -64,12 +64,6 @@ class TestScoring:
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
 
-        for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
-
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset",
@@ -103,7 +97,7 @@ class TestScoring:
 
     @pytest.mark.asyncio
     async def test_scoring_score_with_params_llm_as_judge(
-        self, scoring_stack, sample_judge_prompt_template
+        self, scoring_stack, sample_judge_prompt_template, judge_model
     ):
         (
             scoring_impl,
@@ -122,12 +116,6 @@ class TestScoring:
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
 
-        for model_id in ["Llama3.1-405B-Instruct"]:
-            await models_impl.register_model(
-                model_id=model_id,
-                provider_id="",
-            )
-
         scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
         provider_id = scoring_fns_list[0].provider_id
         if provider_id == "braintrust" or provider_id == "basic":
@@ -142,7 +130,7 @@ class TestScoring:
 
         scoring_functions = {
             "llm-as-judge::base": LLMAsJudgeScoringFnParams(
-                judge_model="Llama3.1-405B-Instruct",
+                judge_model=judge_model,
                 prompt_template=sample_judge_prompt_template,
                 judge_score_regexes=[r"Score: (\d+)"],
                 aggregation_functions=[AggregationFunctionType.categorical_count],
@@ -170,7 +158,7 @@ class TestScoring:
 
     @pytest.mark.asyncio
     async def test_scoring_score_with_aggregation_functions(
-        self, scoring_stack, sample_judge_prompt_template
+        self, scoring_stack, sample_judge_prompt_template, judge_model
     ):
         (
             scoring_impl,
@@ -204,7 +192,7 @@ class TestScoring:
             if x.provider_id == "llm-as-judge":
                 aggr_fns = [AggregationFunctionType.categorical_count]
                 scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
-                    judge_model="Llama3.1-405B-Instruct",
+                    judge_model=judge_model,
                     prompt_template=sample_judge_prompt_template,
                     judge_score_regexes=[r"Score: (\d+)"],
                     aggregation_functions=aggr_fns,