From 68346fac394a86f1d1d9b40b06eac16e94806f26 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sun, 27 Oct 2024 16:43:02 -0700
Subject: [PATCH] datasetio test fix

---
 MANIFEST.in                                        |  2 +-
 .../impls/braintrust/scoring/braintrust.py         | 63 +++++++++---------
 .../braintrust/scoring/scoring_fn/__init__.py      |  5 ++
 .../scoring_fn/fn_defs/factuality.json             | 10 +++
 .../tests/datasetio/test_datasetio.py              |  2 +-
 .../providers/tests/scoring/test_scoring.py        | 64 +++++++++++++++----
 6 files changed, 100 insertions(+), 46 deletions(-)
 create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py
 create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json

diff --git a/MANIFEST.in b/MANIFEST.in
index 09c34cec5..56cb4440c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -2,4 +2,4 @@ include requirements.txt
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/build.yaml
-include llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/*.json
+include llama_stack/providers/impls/*/scoring/scoring_fn/fn_defs/*.json

diff --git a/llama_stack/providers/impls/braintrust/scoring/braintrust.py b/llama_stack/providers/impls/braintrust/scoring/braintrust.py
index b459c3718..ba8ca9e15 100644
--- a/llama_stack/providers/impls/braintrust/scoring/braintrust.py
+++ b/llama_stack/providers/impls/braintrust/scoring/braintrust.py
@@ -11,6 +11,8 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
+from autoevals.llm import Factuality
+from autoevals.ragas import AnswerCorrectness
 
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 
 from .config import BraintrustScoringConfig
@@ -28,29 +30,29 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.datasets_api = datasets_api
         self.scoring_fn_id_impls = {}
 
-    async def initialize(self) -> None: ...
-
-    # for x in FIXED_FNS:
-    #     impl = x()
-    #     await impl.initialize()
-    #     for fn_defs in impl.get_supported_scoring_fn_defs():
-    #         self.scoring_fn_id_impls[fn_defs.identifier] = impl
-    # for x in LLM_JUDGE_FNS:
-    #     impl = x(inference_api=self.inference_api)
-    #     await impl.initialize()
-    #     for fn_defs in impl.get_supported_scoring_fn_defs():
-    #         self.scoring_fn_id_impls[fn_defs.identifier] = impl
-    #     self.llm_as_judge_fn = impl
+    async def initialize(self) -> None:
+        self.scoring_fn_id_impls = {
+            "braintrust::factuality": Factuality(),
+            "braintrust::answer-correctness": AnswerCorrectness(),
+        }
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        return []
-        # return [
-        #     fn_defs
-        #     for impl in self.scoring_fn_id_impls.values()
-        #     for fn_defs in impl.get_supported_scoring_fn_defs()
-        # ]
+        return [
+            ScoringFnDef(
+                identifier="braintrust::factuality",
+                description="Test whether an output is factual, compared to an original (`expected`) value.",
+                parameters=[],
+                return_type=NumberType(),
+            ),
+            ScoringFnDef(
+                identifier="braintrust::answer-correctness",
+                description="Test whether an output is correct, compared to an original (`expected`) value.",
+                parameters=[],
+                return_type=NumberType(),
+            ),
+        ]
 
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
         # self.llm_as_judge_fn.register_scoring_fn_def(function_def)
@@ -81,7 +83,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         print("score_batch")
-        # await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
         # all_rows = await self.datasetio_api.get_rows_paginated(
         #     dataset_id=dataset_id,
         #     rows_in_page=-1,
@@ -103,16 +105,17 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     ) -> ScoreResponse:
         res = {}
         print("score")
-        # for scoring_fn_id in scoring_functions:
-        #     if scoring_fn_id not in self.scoring_fn_id_impls:
-        #         raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-        #     scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-        #     score_results = await scoring_fn.score(input_rows, scoring_fn_id)
-        #     agg_results = await scoring_fn.aggregate(score_results)
-        #     res[scoring_fn_id] = ScoringResult(
-        #         score_rows=score_results,
-        #         aggregated_results=agg_results,
-        #     )
+        for scoring_fn_id in scoring_functions:
+            if scoring_fn_id not in self.scoring_fn_id_impls:
+                raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
+
+            # scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
+            # score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            # agg_results = await scoring_fn.aggregate(score_results)
+            # res[scoring_fn_id] = ScoringResult(
+            #     score_rows=score_results,
+            #     aggregated_results=agg_results,
+            # )
 
         return ScoreResponse(
             results=res,
diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json
new file mode 100644
index 000000000..0f777f98d
--- /dev/null
+++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json
@@ -0,0 +1,10 @@
+{
+    "identifier": "braintrust::factuality",
+    "description": "Test whether an output is factual, compared to an original (`expected`) value.",
", + "metadata": {}, + "parameters": [], + "return_type": { + "type": "number" + }, + "context": null +} diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 743e191d4..d24ea5ee2 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -82,7 +82,7 @@ async def register_dataset( dataset = DatasetDefWithProvider( identifier=dataset_id, - provider_id=os.environ["PROVIDER_ID"], + provider_id=os.environ["PROVIDER_ID"] or os.environ["DATASETIO_PROVIDER_ID"], url=URL( uri=test_url, ), diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index 86deecc71..423d8a573 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -43,16 +43,35 @@ async def scoring_settings(): } +@pytest_asyncio.fixture(scope="session") +async def provider_scoring_functions(): + return { + "meta-reference": { + "meta-reference::equality", + "meta-reference::subset_of", + "meta-reference::llm_as_judge_8b_correctness", + }, + "braintrust": { + "braintrust::factuality", + "braintrust::answer-correctness", + }, + } + + @pytest.mark.asyncio -async def test_scoring_functions_list(scoring_settings): +async def test_scoring_functions_list(scoring_settings, provider_scoring_functions): + scoring_impl = scoring_settings["scoring_impl"] scoring_functions_impl = scoring_settings["scoring_functions_impl"] scoring_functions = await scoring_functions_impl.list_scoring_functions() assert isinstance(scoring_functions, list) assert len(scoring_functions) > 0 function_ids = [f.identifier for f in scoring_functions] - assert "meta-reference::equality" in function_ids - assert "meta-reference::subset_of" in function_ids - assert "meta-reference::llm_as_judge_8b_correctness" in function_ids + # get current provider_type we're testing + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + + for x in provider_scoring_functions[provider_type]: + assert x in function_ids @pytest.mark.asyncio @@ -60,6 +79,17 @@ async def test_scoring_functions_register(scoring_settings): scoring_impl = scoring_settings["scoring_impl"] scoring_functions_impl = scoring_settings["scoring_functions_impl"] datasets_impl = scoring_settings["datasets_impl"] + + # get current provider_type we're testing + scoring_functions = await scoring_functions_impl.list_scoring_functions() + function_ids = [f.identifier for f in scoring_functions] + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + if provider_type not in ("meta-reference"): + pytest.skip( + "Other scoring providers don't support registering scoring functions." + ) + test_prompt = """Output a number between 0 to 10. 
     # register the scoring function
     await scoring_functions_impl.register_scoring_function(
@@ -97,24 +127,30 @@
 
 
 @pytest.mark.asyncio
-async def test_scoring_score(scoring_settings):
+async def test_scoring_score(scoring_settings, provider_scoring_functions):
     scoring_impl = scoring_settings["scoring_impl"]
     datasets_impl = scoring_settings["datasets_impl"]
+    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
     await register_dataset(datasets_impl)
 
     response = await datasets_impl.list_datasets()
     assert len(response) == 1
 
+    # get the provider_type currently under test
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    function_ids = [f.identifier for f in scoring_functions]
+    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
+    provider_type = provider.__provider_spec__.provider_type
+    if provider_type not in ("meta-reference",):
+        pytest.skip(
+            "Other scoring providers are not yet supported by this test."
+        )
+
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::equality",
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        ],
+        scoring_functions=list(provider_scoring_functions[provider_type]),
     )
 
-    assert len(response.results) == 3
-    assert "meta-reference::equality" in response.results
-    assert "meta-reference::subset_of" in response.results
-    assert "meta-reference::llm_as_judge_8b_correctness" in response.results
+    assert len(response.results) == len(provider_scoring_functions[provider_type])
+    for x in provider_scoring_functions[provider_type]:
+        assert x in response.results
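
Below is a hedged, illustrative sketch (not part of the patch) of how the autoevals
scorers registered in `initialize` could be driven once the commented-out body of
`score` is filled in. The row keys ("generated_answer", "expected_answer",
"input_query"), the helper names, and the mean aggregation are assumptions for
illustration only; the autoevals scorers also need an OpenAI/Braintrust credential
at call time.

# braintrust_score_sketch.py -- hypothetical helper, not defined by this diff
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness

# Mirrors the mapping built in BraintrustScoringImpl.initialize
SCORERS = {
    "braintrust::factuality": Factuality(),
    "braintrust::answer-correctness": AnswerCorrectness(),
}


def run_scorer(input_rows: list[dict], scoring_fn_id: str) -> list[dict]:
    """Score each row with the selected autoevals scorer (assumed row keys)."""
    scorer = SCORERS[scoring_fn_id]
    score_rows = []
    for row in input_rows:
        result = scorer(
            row["generated_answer"],   # model output under evaluation
            row["expected_answer"],    # reference ("expected") value
            input=row["input_query"],  # original prompt; used by some scorers
        )
        score_rows.append({"score": result.score, "metadata": result.metadata})
    return score_rows


def aggregate_average(score_rows: list[dict]) -> dict:
    """Simple mean over per-row scores; the patch does not define an aggregation."""
    scores = [r["score"] for r in score_rows if r["score"] is not None]
    return {"average": sum(scores) / len(scores) if scores else 0.0}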