From ed833bb758d347db1c5b6194b5e84cce24158873 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 28 Oct 2024 18:59:35 -0700 Subject: [PATCH] [Evals API][7/n] braintrust scoring provider (#333) * wip scoring refactor * llm as judge, move folders * test full generation + eval * extract score regex to llm context * remove prints, cleanup braintrust in this branch * braintrust skeleton * datasetio test fix * braintrust provider * remove prints * dependencies * change json -> class * json -> class * remove initialize * address nits * check identifier prefix * braintrust scoring identifier check, rebase * udpate MANIFEST * manifest * remove braintrust scoring_fn * remove comments * tests * imports fix --- .../impls/braintrust/scoring/__init__.py | 21 +++ .../impls/braintrust/scoring/braintrust.py | 140 ++++++++++++++++++ .../impls/braintrust/scoring/config.py | 9 ++ .../braintrust/scoring/scoring_fn/__init__.py | 5 + .../scoring/scoring_fn/fn_defs/__init__.py | 5 + .../scoring_fn/fn_defs/answer_correctness.py | 16 ++ .../scoring/scoring_fn/fn_defs/factuality.py | 16 ++ llama_stack/providers/registry/scoring.py | 11 ++ .../tests/datasetio/test_datasetio.py | 3 +- .../scoring/provider_config_example.yaml | 3 + .../providers/tests/scoring/test_scoring.py | 60 ++++++-- 11 files changed, 274 insertions(+), 15 deletions(-) create mode 100644 llama_stack/providers/impls/braintrust/scoring/__init__.py create mode 100644 llama_stack/providers/impls/braintrust/scoring/braintrust.py create mode 100644 llama_stack/providers/impls/braintrust/scoring/config.py create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/__init__.py create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.py diff --git a/llama_stack/providers/impls/braintrust/scoring/__init__.py b/llama_stack/providers/impls/braintrust/scoring/__init__.py new file mode 100644 index 000000000..f442a6c3b --- /dev/null +++ b/llama_stack/providers/impls/braintrust/scoring/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Dict + +from llama_stack.distribution.datatypes import Api, ProviderSpec + +from .config import BraintrustScoringConfig + + +async def get_provider_impl( + config: BraintrustScoringConfig, + deps: Dict[Api, ProviderSpec], +): + from .braintrust import BraintrustScoringImpl + + impl = BraintrustScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/impls/braintrust/scoring/braintrust.py b/llama_stack/providers/impls/braintrust/scoring/braintrust.py new file mode 100644 index 000000000..826d60379 --- /dev/null +++ b/llama_stack/providers/impls/braintrust/scoring/braintrust.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from typing import List + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 + +# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn +from autoevals.llm import Factuality +from autoevals.ragas import AnswerCorrectness +from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( + aggregate_average, +) + +from .config import BraintrustScoringConfig +from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def +from .scoring_fn.fn_defs.factuality import factuality_fn_def + + +class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): + def __init__( + self, + config: BraintrustScoringConfig, + datasetio_api: DatasetIO, + datasets_api: Datasets, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + + self.braintrust_evaluators = { + "braintrust::factuality": Factuality(), + "braintrust::answer-correctness": AnswerCorrectness(), + } + self.supported_fn_defs_registry = { + factuality_fn_def.identifier: factuality_fn_def, + answer_correctness_fn_def.identifier: answer_correctness_fn_def, + } + + async def initialize(self) -> None: ... + + async def shutdown(self) -> None: ... + + async def list_scoring_functions(self) -> List[ScoringFnDef]: + scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()] + for f in scoring_fn_defs_list: + assert f.identifier.startswith( + "braintrust" + ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! " + + return scoring_fn_defs_list + + async def register_scoring_function(self, function_def: ScoringFnDef) -> None: + raise NotImplementedError( + "Registering scoring function not allowed for braintrust provider" + ) + + async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: + dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: + raise ValueError( + f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." + ) + + for required_column in ["generated_answer", "expected_answer", "input_query"]: + if required_column not in dataset_def.dataset_schema: + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column." + ) + if dataset_def.dataset_schema[required_column].type != "string": + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." 
+ ) + + async def score_batch( + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + all_rows = await self.datasetio_api.get_rows_paginated( + dataset_id=dataset_id, + rows_in_page=-1, + ) + res = await self.score( + input_rows=all_rows.rows, scoring_functions=scoring_functions + ) + if save_results_dataset: + # TODO: persist and register dataset on to server for reading + # self.datasets_api.register_dataset() + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res.results, + ) + + async def score_row( + self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None + ) -> ScoringResultRow: + assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" + expected_answer = input_row["expected_answer"] + generated_answer = input_row["generated_answer"] + input_query = input_row["input_query"] + evaluator = self.braintrust_evaluators[scoring_fn_identifier] + + result = evaluator(generated_answer, expected_answer, input=input_query) + score = result.score + return {"score": score, "metadata": result.metadata} + + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: + res = {} + for scoring_fn_id in scoring_functions: + if scoring_fn_id not in self.supported_fn_defs_registry: + raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") + + score_results = [ + await self.score_row(input_row, scoring_fn_id) + for input_row in input_rows + ] + + agg_results = aggregate_average(score_results) + res[scoring_fn_id] = ScoringResult( + score_rows=score_results, + aggregated_results=agg_results, + ) + + return ScoreResponse( + results=res, + ) diff --git a/llama_stack/providers/impls/braintrust/scoring/config.py b/llama_stack/providers/impls/braintrust/scoring/config.py new file mode 100644 index 000000000..fef6df5c8 --- /dev/null +++ b/llama_stack/providers/impls/braintrust/scoring/config.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.scoring import * # noqa: F401, F403 + + +class BraintrustScoringConfig(BaseModel): ... diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
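For reviewers unfamiliar with autoevals, the short sketch below (not part of the diff) shows roughly what BraintrustScoringImpl.score_row delegates to for a single row. It assumes the autoevals and openai packages listed in the provider's pip_packages are installed and that an OpenAI API key is configured in the environment; the question/answer strings are illustrative placeholders.

# Minimal sketch of the autoevals calls wrapped by score_row (illustrative only).
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness

factuality = Factuality()
result = factuality(
    "Paris is the capital of France.",       # generated_answer
    "The capital of France is Paris.",       # expected_answer
    input="What is the capital of France?",  # input_query
)
print(result.score, result.metadata)  # numeric score plus judge metadata (e.g. a rationale)

answer_correctness = AnswerCorrectness()
result = answer_correctness(
    "Paris is the capital of France.",
    "The capital of France is Paris.",
    input="What is the capital of France?",
)
print(result.score)

score_row simply looks the evaluator up in braintrust_evaluators by identifier and returns {"score": result.score, "metadata": result.metadata}, so the provider's per-row behavior is exactly that of the underlying autoevals scorer.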
diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py new file mode 100644 index 000000000..ca6a46d0e --- /dev/null +++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ScoringFnDef + + +answer_correctness_fn_def = ScoringFnDef( + identifier="braintrust::answer-correctness", + description="Scores how well the generated answer matches the expected answer. One of Braintrust's LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/ragas.py", + parameters=[], + return_type=NumberType(), +) diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.py new file mode 100644 index 000000000..cbf9cd01c --- /dev/null +++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ScoringFnDef + + +factuality_fn_def = ScoringFnDef( + identifier="braintrust::factuality", + description="Test whether an output is factual, compared to an original (`expected`) value.
One of Braintrust's LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", + parameters=[], + return_type=NumberType(), +) diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index 06983cdee..81cb47764 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -23,4 +23,15 @@ def available_providers() -> List[ProviderSpec]: Api.inference, ], ), + InlineProviderSpec( + api=Api.scoring, + provider_type="braintrust", + pip_packages=["autoevals", "openai"], + module="llama_stack.providers.impls.braintrust.scoring", + config_class="llama_stack.providers.impls.braintrust.scoring.BraintrustScoringConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + ), ] diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 743e191d4..866b1e270 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -82,7 +82,8 @@ async def register_dataset( dataset = DatasetDefWithProvider( identifier=dataset_id, - provider_id=os.environ["PROVIDER_ID"], + provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None) + or os.environ["PROVIDER_ID"], url=URL( uri=test_url, ), diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml index 9cf5713c1..6a9c0d842 100644 --- a/llama_stack/providers/tests/scoring/provider_config_example.yaml +++ b/llama_stack/providers/tests/scoring/provider_config_example.yaml @@ -7,6 +7,9 @@ providers: - provider_id: test-meta provider_type: meta-reference config: {} + - provider_id: test-braintrust + provider_type: braintrust + config: {} inference: - provider_id: tgi0 provider_type: remote::tgi diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index 86deecc71..b9b920739 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -43,16 +43,35 @@ async def scoring_settings(): } +@pytest_asyncio.fixture(scope="session") +async def provider_scoring_functions(): + return { + "meta-reference": { + "meta-reference::equality", + "meta-reference::subset_of", + "meta-reference::llm_as_judge_8b_correctness", + }, + "braintrust": { + "braintrust::factuality", + "braintrust::answer-correctness", + }, + } + + @pytest.mark.asyncio -async def test_scoring_functions_list(scoring_settings): +async def test_scoring_functions_list(scoring_settings, provider_scoring_functions): + scoring_impl = scoring_settings["scoring_impl"] scoring_functions_impl = scoring_settings["scoring_functions_impl"] scoring_functions = await scoring_functions_impl.list_scoring_functions() assert isinstance(scoring_functions, list) assert len(scoring_functions) > 0 function_ids = [f.identifier for f in scoring_functions] - assert "meta-reference::equality" in function_ids - assert "meta-reference::subset_of" in function_ids - assert "meta-reference::llm_as_judge_8b_correctness" in function_ids + # get current provider_type we're testing + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + + for x in provider_scoring_functions[provider_type]: + assert x in function_ids @pytest.mark.asyncio @@ -60,6 +79,17 @@ async def 
test_scoring_functions_register(scoring_settings): scoring_impl = scoring_settings["scoring_impl"] scoring_functions_impl = scoring_settings["scoring_functions_impl"] datasets_impl = scoring_settings["datasets_impl"] + + # get current provider_type we're testing + scoring_functions = await scoring_functions_impl.list_scoring_functions() + function_ids = [f.identifier for f in scoring_functions] + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + if provider_type not in ("meta-reference",): + pytest.skip( + "Other scoring providers don't support registering scoring functions." + ) + test_prompt = """Output a number between 0 and 10. Your answer must match the format \n Number: """ # register the scoring function await scoring_functions_impl.register_scoring_function( @@ -97,24 +127,26 @@ async def test_scoring_functions_register(scoring_settings): @pytest.mark.asyncio -async def test_scoring_score(scoring_settings): +async def test_scoring_score(scoring_settings, provider_scoring_functions): scoring_impl = scoring_settings["scoring_impl"] datasets_impl = scoring_settings["datasets_impl"] + scoring_functions_impl = scoring_settings["scoring_functions_impl"] await register_dataset(datasets_impl) response = await datasets_impl.list_datasets() assert len(response) == 1 + # get current provider_type we're testing + scoring_functions = await scoring_functions_impl.list_scoring_functions() + function_ids = [f.identifier for f in scoring_functions] + provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) + provider_type = provider.__provider_spec__.provider_type + response = await scoring_impl.score_batch( dataset_id=response[0].identifier, - scoring_functions=[ - "meta-reference::equality", - "meta-reference::subset_of", - "meta-reference::llm_as_judge_8b_correctness", - ], + scoring_functions=list(provider_scoring_functions[provider_type]), ) - assert len(response.results) == 3 - assert "meta-reference::equality" in response.results - assert "meta-reference::subset_of" in response.results - assert "meta-reference::llm_as_judge_8b_correctness" in response.results
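End to end, the flow the updated test_scoring_score exercises against the new provider looks roughly like the sketch below. It is illustrative only: scoring_impl and datasets_impl are assumed to be already-resolved API implementations (as in the test fixtures), the wrapper name score_with_braintrust is hypothetical, and since autoevals' LLM scorers call out to OpenAI by default (hence openai in pip_packages), an OpenAI API key is presumably required in the environment.

# Illustrative driver mirroring test_scoring_score; not part of the diff.
async def score_with_braintrust(scoring_impl, datasets_impl):
    datasets = await datasets_impl.list_datasets()
    response = await scoring_impl.score_batch(
        dataset_id=datasets[0].identifier,
        scoring_functions=[
            "braintrust::factuality",
            "braintrust::answer-correctness",
        ],
    )
    for fn_id, result in response.results.items():
        # Each ScoringResult carries per-row score rows plus the average
        # computed by aggregate_average in the provider's score() method.
        print(fn_id, result.aggregated_results, len(result.score_rows))

Note that save_results_dataset=True still raises NotImplementedError and register_scoring_function is intentionally rejected for this provider, so only the two built-in braintrust:: scoring functions are usable.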