diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
index 20a67edc7..60ce74477 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
@@ -28,7 +28,7 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     description="Llm As Judge Scoring Function",
     parameters=[],
     return_type=NumberType(),
-    context=LLMAsJudgeContext(
+    context=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
         judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 9b27ed94a..459b58f22 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -150,4 +150,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.memory.fixtures",
     "llama_stack.providers.tests.agents.fixtures",
     "llama_stack.providers.tests.datasetio.fixtures",
+    "llama_stack.providers.tests.scoring.fixtures",
 ]
diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py
index 9ea15c9b7..c02794c50 100644
--- a/llama_stack/providers/tests/datasetio/test_datasetio.py
+++ b/llama_stack/providers/tests/datasetio/test_datasetio.py
@@ -16,7 +16,7 @@ from pathlib import Path
 
 # How to run this test:
 #
-# pytest llama_stack/providers/tests/memory/test_memory.py
+# pytest llama_stack/providers/tests/datasetio/test_datasetio.py
 #   -m "meta_reference"
 #   -v -s --tb=short --disable-warnings
 
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval_old.py
similarity index 100%
rename from llama_stack/providers/tests/eval/test_eval.py
rename to llama_stack/providers/tests/eval/test_eval_old.py
diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py
new file mode 100644
index 000000000..698d4a60a
--- /dev/null
+++ b/llama_stack/providers/tests/scoring/conftest.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+
+from ..datasetio.fixtures import DATASETIO_FIXTURES
+from ..inference.fixtures import INFERENCE_FIXTURES
+from .fixtures import SCORING_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "scoring": "meta_reference",
+            "datasetio": "meta_reference",
+            "inference": "fireworks",
+        },
+        id="meta_reference_scoring_fireworks_inference",
+        marks=pytest.mark.meta_reference_scoring_fireworks_inference,
+    )
+]
+
+
+def pytest_configure(config):
+    for fixture_name in ["meta_reference_scoring_fireworks_inference"]:
+        config.addinivalue_line(
+            "markers",
+            f"{fixture_name}: marks tests as {fixture_name} specific",
+        )
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default="Llama3.2-3B-Instruct",
+        help="Specify the inference model to use for testing",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "scoring_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "scoring": SCORING_FIXTURES,
+            "datasetio": DATASETIO_FIXTURES,
+            "inference": INFERENCE_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("scoring_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
new file mode 100644
index 000000000..470fad215
--- /dev/null
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.distribution.datatypes import Api, Provider
+
+from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from ..conftest import ProviderFixture, remote_stack_fixture
+
+
+@pytest.fixture(scope="session")
+def scoring_remote() -> ProviderFixture:
+    return remote_stack_fixture()
+
+
+@pytest.fixture(scope="session")
+def scoring_meta_reference() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="meta-reference",
+                provider_type="meta-reference",
+                config={},
+            )
+        ],
+    )
+
+
+SCORING_FIXTURES = ["meta_reference", "remote"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def scoring_stack(request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["datasetio", "scoring", "inference"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+
+    impls = await resolve_impls_for_test_v2(
+        [Api.scoring, Api.datasetio, Api.inference],
+        providers,
+        provider_data,
+    )
+
+    return impls[Api.scoring], impls[Api.scoring_functions]
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index b9b920739..1b50cbc38 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -3,150 +3,23 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+ + import pytest -import pytest_asyncio - -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 - -from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset -from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \ -# --tb=short --disable-warnings -# ``` +# pytest llama_stack/providers/tests/scoring/test_scoring.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings -@pytest_asyncio.fixture(scope="session") -async def scoring_settings(): - impls = await resolve_impls_for_test( - Api.scoring, deps=[Api.datasetio, Api.inference] - ) - return { - "scoring_impl": impls[Api.scoring], - "scoring_functions_impl": impls[Api.scoring_functions], - "datasets_impl": impls[Api.datasets], - } - - -@pytest_asyncio.fixture(scope="session") -async def provider_scoring_functions(): - return { - "meta-reference": { - "meta-reference::equality", - "meta-reference::subset_of", - "meta-reference::llm_as_judge_8b_correctness", - }, - "braintrust": { - "braintrust::factuality", - "braintrust::answer-correctness", - }, - } - - -@pytest.mark.asyncio -async def test_scoring_functions_list(scoring_settings, provider_scoring_functions): - scoring_impl = scoring_settings["scoring_impl"] - scoring_functions_impl = scoring_settings["scoring_functions_impl"] - scoring_functions = await scoring_functions_impl.list_scoring_functions() - assert isinstance(scoring_functions, list) - assert len(scoring_functions) > 0 - function_ids = [f.identifier for f in scoring_functions] - # get current provider_type we're testing - provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) - provider_type = provider.__provider_spec__.provider_type - - for x in provider_scoring_functions[provider_type]: - assert x in function_ids - - -@pytest.mark.asyncio -async def test_scoring_functions_register(scoring_settings): - scoring_impl = scoring_settings["scoring_impl"] - scoring_functions_impl = scoring_settings["scoring_functions_impl"] - datasets_impl = scoring_settings["datasets_impl"] - - # get current provider_type we're testing - scoring_functions = await scoring_functions_impl.list_scoring_functions() - function_ids = [f.identifier for f in scoring_functions] - provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) - provider_type = provider.__provider_spec__.provider_type - if provider_type not in ("meta-reference"): - pytest.skip( - "Other scoring providers don't support registering scoring functions." - ) - - test_prompt = """Output a number between 0 to 10. 
Your answer must match the format \n Number: """
-    # register the scoring function
-    await scoring_functions_impl.register_scoring_function(
-        ScoringFnDefWithProvider(
-            identifier="meta-reference::llm_as_judge_8b_random",
-            description="Llm As Judge Scoring Function",
-            parameters=[],
-            return_type=NumberType(),
-            context=LLMAsJudgeContext(
-                prompt_template=test_prompt,
-                judge_model="Llama3.1-8B-Instruct",
-                judge_score_regex=[r"Number: (\d+)"],
-            ),
-            provider_id="test-meta",
-        )
-    )
-
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::llm_as_judge_8b_random" in function_ids
-
-    # test score using newly registered scoring function
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::llm_as_judge_8b_random",
-        ],
-    )
-    assert "meta-reference::llm_as_judge_8b_random" in response.results
-
-
-@pytest.mark.asyncio
-async def test_scoring_score(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    await register_dataset(datasets_impl)
-
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=list(provider_scoring_functions[provider_type]),
-    )
-
-    assert len(response.results) == len(provider_scoring_functions[provider_type])
-    for x in provider_scoring_functions[provider_type]:
-        assert x in response.results
+class TestScoring:
+    @pytest.mark.asyncio
+    async def test_scoring_functions_list(self, scoring_stack):
+        # NOTE: this needs you to ensure that you are starting from a clean state
+        # but so far we don't have an unregister API unfortunately, so be careful
+        _, scoring_functions_impl = scoring_stack
+        scoring_functions = await scoring_functions_impl.list_scoring_functions()
+        assert isinstance(scoring_functions, list)
+        assert len(scoring_functions) > 0
diff --git a/llama_stack/providers/tests/scoring/test_scoring_old.py b/llama_stack/providers/tests/scoring/test_scoring_old.py
new file mode 100644
index 000000000..b9b920739
--- /dev/null
+++ b/llama_stack/providers/tests/scoring/test_scoring_old.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+
+# How to run this test:
+#
+# 1. 
Ensure you have a conda with the right dependencies installed. This is a bit tricky +# since it depends on the provider you are testing. On top of that you need +# `pytest` and `pytest-asyncio` installed. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. Run: +# +# ```bash +# PROVIDER_ID= \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \ +# --tb=short --disable-warnings +# ``` + + +@pytest_asyncio.fixture(scope="session") +async def scoring_settings(): + impls = await resolve_impls_for_test( + Api.scoring, deps=[Api.datasetio, Api.inference] + ) + return { + "scoring_impl": impls[Api.scoring], + "scoring_functions_impl": impls[Api.scoring_functions], + "datasets_impl": impls[Api.datasets], + } + + +@pytest_asyncio.fixture(scope="session") +async def provider_scoring_functions(): + return { + "meta-reference": { + "meta-reference::equality", + "meta-reference::subset_of", + "meta-reference::llm_as_judge_8b_correctness", + }, + "braintrust": { + "braintrust::factuality", + "braintrust::answer-correctness", + }, + } + + +@pytest.mark.asyncio +async def test_scoring_functions_list(scoring_settings, provider_scoring_functions): + scoring_impl = scoring_settings["scoring_impl"] + scoring_functions_impl = scoring_settings["scoring_functions_impl"] + scoring_functions = await scoring_functions_impl.list_scoring_functions() + assert isinstance(scoring_functions, list) + assert len(scoring_functions) > 0 + function_ids = [f.identifier for f in scoring_functions] + # get current provider_type we're testing + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + + for x in provider_scoring_functions[provider_type]: + assert x in function_ids + + +@pytest.mark.asyncio +async def test_scoring_functions_register(scoring_settings): + scoring_impl = scoring_settings["scoring_impl"] + scoring_functions_impl = scoring_settings["scoring_functions_impl"] + datasets_impl = scoring_settings["datasets_impl"] + + # get current provider_type we're testing + scoring_functions = await scoring_functions_impl.list_scoring_functions() + function_ids = [f.identifier for f in scoring_functions] + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + if provider_type not in ("meta-reference"): + pytest.skip( + "Other scoring providers don't support registering scoring functions." + ) + + test_prompt = """Output a number between 0 to 10. 
Your answer must match the format \n Number: """ + # register the scoring function + await scoring_functions_impl.register_scoring_function( + ScoringFnDefWithProvider( + identifier="meta-reference::llm_as_judge_8b_random", + description="Llm As Judge Scoring Function", + parameters=[], + return_type=NumberType(), + context=LLMAsJudgeContext( + prompt_template=test_prompt, + judge_model="Llama3.1-8B-Instruct", + judge_score_regex=[r"Number: (\d+)"], + ), + provider_id="test-meta", + ) + ) + + scoring_functions = await scoring_functions_impl.list_scoring_functions() + assert isinstance(scoring_functions, list) + assert len(scoring_functions) > 0 + function_ids = [f.identifier for f in scoring_functions] + assert "meta-reference::llm_as_judge_8b_random" in function_ids + + # test score using newly registered scoring function + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert len(response) == 1 + response = await scoring_impl.score_batch( + dataset_id=response[0].identifier, + scoring_functions=[ + "meta-reference::llm_as_judge_8b_random", + ], + ) + assert "meta-reference::llm_as_judge_8b_random" in response.results + + +@pytest.mark.asyncio +async def test_scoring_score(scoring_settings, provider_scoring_functions): + scoring_impl = scoring_settings["scoring_impl"] + datasets_impl = scoring_settings["datasets_impl"] + scoring_functions_impl = scoring_settings["scoring_functions_impl"] + await register_dataset(datasets_impl) + + response = await datasets_impl.list_datasets() + assert len(response) == 1 + + # get current provider_type we're testing + scoring_functions = await scoring_functions_impl.list_scoring_functions() + function_ids = [f.identifier for f in scoring_functions] + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + + response = await scoring_impl.score_batch( + dataset_id=response[0].identifier, + scoring_functions=list(provider_scoring_functions[provider_type]), + ) + + assert len(response.results) == len(provider_scoring_functions[provider_type]) + for x in provider_scoring_functions[provider_type]: + assert x in response.results