From 38186f7903b6df31f5c09bf145bdcc31bbc5b1e7 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sun, 27 Oct 2024 17:24:10 -0700
Subject: [PATCH] braintrust provider

---
 .../impls/braintrust/scoring/braintrust.py    | 84 +++++++++----------
 .../scoring_fn/braintrust_scoring_fn.py       | 61 ++++++++++++++
 .../fn_defs/answer-correctness.json           | 10 +++
 .../scoring_fn/fn_defs/factuality.json        |  2 +-
 .../tests/datasetio/test_datasetio.py         |  2 +-
 .../providers/tests/scoring/test_scoring.py   |  6 +-
 6 files changed, 116 insertions(+), 49 deletions(-)
 create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py
 create mode 100644 llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer-correctness.json

diff --git a/llama_stack/providers/impls/braintrust/scoring/braintrust.py b/llama_stack/providers/impls/braintrust/scoring/braintrust.py
index ba8ca9e15..c8d443337 100644
--- a/llama_stack/providers/impls/braintrust/scoring/braintrust.py
+++ b/llama_stack/providers/impls/braintrust/scoring/braintrust.py
@@ -11,12 +11,12 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
-from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness
 
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 
 from .config import BraintrustScoringConfig
+from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
+
 
 class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
@@ -28,36 +28,29 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
-        self.scoring_fn_id_impls = {}
+        self.braintrust_scoring_fn_impl = None
+        self.supported_fn_ids = {}
 
     async def initialize(self) -> None:
-        self.scoring_fn_id_impls = {
-            "braintrust::factuality": Factuality(),
-            "braintrust::answer-correctness": AnswerCorrectness(),
+        self.braintrust_scoring_fn_impl = BraintrustScoringFn()
+        await self.braintrust_scoring_fn_impl.initialize()
+        self.supported_fn_ids = {
+            x.identifier
+            for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
         }
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        return [
-            ScoringFnDef(
-                identifier="braintrust::factuality",
-                description="Test whether an output is factual, compared to an original (`expected`) value.",
-                parameters=[],
-                return_type=NumberType(),
-            ),
-            ScoringFnDef(
-                identifier="braintrust::answer-correctness",
-                description="Test whether an output is factual, compared to an original (`expected`) value.",
-                parameters=[],
-                return_type=NumberType(),
-            ),
-        ]
+        assert (
+            self.braintrust_scoring_fn_impl is not None
+        ), "braintrust_scoring_fn_impl is not initialized; call initialize() on the provider first."
+        return self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
 
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        # self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        # self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
-        return None
+        raise NotImplementedError(
+            "Registering a scoring function is not allowed for the braintrust provider"
+        )
 
     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -82,19 +75,18 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         scoring_functions: List[str],
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        print("score_batch")
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
-        # all_rows = await self.datasetio_api.get_rows_paginated(
-        #     dataset_id=dataset_id,
-        #     rows_in_page=-1,
-        # )
-        # res = await self.score(
-        #     input_rows=all_rows.rows, scoring_functions=scoring_functions
-        # )
-        # if save_results_dataset:
-        #     # TODO: persist and register dataset on to server for reading
-        #     # self.datasets_api.register_dataset()
-        #     raise NotImplementedError("Save results dataset not implemented yet")
+        all_rows = await self.datasetio_api.get_rows_paginated(
+            dataset_id=dataset_id,
+            rows_in_page=-1,
+        )
+        res = await self.score(
+            input_rows=all_rows.rows, scoring_functions=scoring_functions
+        )
+        if save_results_dataset:
+            # TODO: persist the results and register the dataset with the server for reading
+            # self.datasets_api.register_dataset()
+            raise NotImplementedError("Saving results as a dataset is not implemented yet")
 
         return ScoreBatchResponse(
             results=res.results,
@@ -103,19 +95,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
+        assert (
+            self.braintrust_scoring_fn_impl is not None
+        ), "braintrust_scoring_fn_impl is not initialized; call initialize() on the provider first."
+
         res = {}
-        print("score")
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in self.scoring_fn_id_impls:
+            if scoring_fn_id not in self.supported_fn_ids:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
+            # scoring_impl = self.scoring_fn_id_impls[scoring_fn_id]
             # scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            # score_results = await scoring_fn.score(input_rows, scoring_fn_id)
-            # agg_results = await scoring_fn.aggregate(score_results)
-            # res[scoring_fn_id] = ScoringResult(
-            #     score_rows=score_results,
-            #     aggregated_results=agg_results,
-            # )
+            score_results = await self.braintrust_scoring_fn_impl.score(
+                input_rows, scoring_fn_id
+            )
+            agg_results = await self.braintrust_scoring_fn_impl.aggregate(score_results)
+            res[scoring_fn_id] = ScoringResult(
+                score_rows=score_results,
+                aggregated_results=agg_results,
+            )
 
         return ScoreResponse(
             results=res,
diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py
new file mode 100644
index 000000000..4663886a5
--- /dev/null
+++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from typing import Any, Dict, List, Optional
+
+# TODO: move the common base out from meta-reference into common
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+)
+from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from autoevals.llm import Factuality
+from autoevals.ragas import AnswerCorrectness
+
+
+BRAINTRUST_FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
+
+
+class BraintrustScoringFn(BaseScoringFn):
+    """
+    Scoring functions backed by Braintrust's autoevals LLM-based scorers
+    (Factuality and AnswerCorrectness).
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.braintrust_evaluators = {
+            "braintrust::factuality": Factuality(),
+            "braintrust::answer-correctness": AnswerCorrectness(),
+        }
+        self.defs_paths = [
+            str(x) for x in sorted(BRAINTRUST_FN_DEFS_PATH.glob("*.json"))
+        ]
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
+        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        input_query = input_row["input_query"]
+        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
+
+        result = evaluator(generated_answer, expected_answer, input=input_query)
+        score = result.score
+        return {"score": score, "metadata": result.metadata}
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_average(scoring_results)
diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer-correctness.json b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer-correctness.json
new file mode 100644
index 000000000..3fc2957a3
--- /dev/null
+++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer-correctness.json
@@ -0,0 +1,10 @@
+{
+    "identifier": "braintrust::answer-correctness",
+    "description": "Scores how correct the generated answer is compared to the expected (`expected`) answer. One of the Braintrust LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": null
+}
diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json
index 0f777f98d..210901d6f 100644
--- a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json
+++ b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.json
@@ -1,6 +1,6 @@
 {
     "identifier": "braintrust::factuality",
-    "description": "Test whether an output is factual, compared to an original (`expected`) value. ",
+    "description": "Test whether an output is factual, compared to an original (`expected`) value. One of the Braintrust LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
     "metadata": {},
     "parameters": [],
     "return_type": {
diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py
index d24ea5ee2..adfe55896 100644
--- a/llama_stack/providers/tests/datasetio/test_datasetio.py
+++ b/llama_stack/providers/tests/datasetio/test_datasetio.py
@@ -82,7 +82,7 @@ async def register_dataset(
 
     dataset = DatasetDefWithProvider(
         identifier=dataset_id,
-        provider_id=os.environ["PROVIDER_ID"] or os.environ["DATASETIO_PROVIDER_ID"],
+        provider_id=os.environ.get("DATASETIO_PROVIDER_ID") or os.environ["PROVIDER_ID"],
         url=URL(
             uri=test_url,
         ),
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 423d8a573..68ba20e17 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -141,16 +141,14 @@ async def test_scoring_score(scoring_settings, provider_scoring_functions):
     function_ids = [f.identifier for f in scoring_functions]
     provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
     provider_type = provider.__provider_spec__.provider_type
-    if provider_type not in ("meta-reference"):
-        pytest.skip(
-            "Other scoring providers don't support registering scoring functions."
-        )
 
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
         scoring_functions=list(provider_scoring_functions[provider_type]),
    )
+    print("RESPONSE", response)
+
     assert len(response.results) == len(provider_scoring_functions[provider_type])
     for x in provider_scoring_functions[provider_type]:
         assert x in response.results
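
A minimal sketch of exercising the new BraintrustScoringFn on a single row, independent of the provider plumbing above. It assumes autoevals is installed and OPENAI_API_KEY is set (both scorers call an LLM under the hood); the row values below are hypothetical and only need to carry the input_query / generated_answer / expected_answer keys that score_row() reads.

import asyncio

from llama_stack.providers.impls.braintrust.scoring.scoring_fn.braintrust_scoring_fn import (
    BraintrustScoringFn,
)


async def main() -> None:
    # Constructed with no arguments, the same way BraintrustScoringImpl.initialize() does.
    scoring_fn = BraintrustScoringFn()

    # Hypothetical sample row; only the three keys read by score_row() matter.
    row = {
        "input_query": "What is the capital of France?",
        "generated_answer": "Paris.",
        "expected_answer": "Paris",
    }

    result = await scoring_fn.score_row(
        row, scoring_fn_identifier="braintrust::factuality"
    )
    print(result["score"], result["metadata"])

    # aggregate() averages the per-row scores via aggregate_average().
    print(await scoring_fn.aggregate([result]))


if __name__ == "__main__":
    asyncio.run(main())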