diff --git a/llama_stack/providers/impls/braintrust/scoring/braintrust.py b/llama_stack/providers/impls/braintrust/scoring/braintrust.py
index 3b8263dfb..3d0c1d751 100644
--- a/llama_stack/providers/impls/braintrust/scoring/braintrust.py
+++ b/llama_stack/providers/impls/braintrust/scoring/braintrust.py
@@ -11,12 +11,23 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
+
+# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
+from autoevals.llm import Factuality
+from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
+    answer_correctness_fn_def,
+)
+from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
+    factuality_fn_def,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+)
 
 from .config import BraintrustScoringConfig
-from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
-
 
 class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
@@ -28,26 +39,31 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
-        self.braintrust_scoring_fn_impl = None
-        self.supported_fn_ids = {}
 
-    async def initialize(self) -> None:
-        self.braintrust_scoring_fn_impl = BraintrustScoringFn()
-        self.supported_fn_ids = {
-            x.identifier
-            for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
+        self.braintrust_evaluators = {
+            "braintrust::factuality": Factuality(),
+            "braintrust::answer-correctness": AnswerCorrectness(),
         }
+        self.supported_fn_defs_registry = {
+            factuality_fn_def.identifier: factuality_fn_def,
+            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
+        }
+
+        # self.braintrust_scoring_fn_impl = None
+        # self.supported_fn_ids = {}
+
+    async def initialize(self) -> None: ...
+
+        # self.braintrust_scoring_fn_impl = BraintrustScoringFn()
+        # self.supported_fn_ids = {
+        #     x.identifier
+        #     for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
+        # }
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        assert (
-            self.braintrust_scoring_fn_impl is not None
-        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
-        scoring_fn_defs_list = (
-            self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
-        )
-
+        scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith(
                 "braintrust"
@@ -100,22 +116,37 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
                 results=res.results,
             )
 
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
+        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        input_query = input_row["input_query"]
+        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
+
+        result = evaluator(generated_answer, expected_answer, input=input_query)
+        score = result.score
+        return {"score": score, "metadata": result.metadata}
+
     async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
-        assert (
-            self.braintrust_scoring_fn_impl is not None
-        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
+        # assert (
+        #     self.braintrust_scoring_fn_impl is not None
+        # ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
         res = {}
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in self.supported_fn_ids:
+            if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            score_results = await self.braintrust_scoring_fn_impl.score(
-                input_rows, scoring_fn_id
-            )
-            agg_results = await self.braintrust_scoring_fn_impl.aggregate(score_results)
+            score_results = [
+                await self.score_row(input_row, scoring_fn_id)
+                for input_row in input_rows
+            ]
+
+            agg_results = aggregate_average(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
diff --git a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py b/llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py
deleted file mode 100644
index fbf9e0bf8..000000000
--- a/llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from pathlib import Path
-
-from typing import Any, Dict, List, Optional
-
-# TODO: move the common base out from meta-reference into common
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
-    aggregate_average,
-)
-from llama_stack.apis.scoring import *  # noqa: F403
-from llama_stack.apis.scoring_functions import *  # noqa: F403
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness
-from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
-    answer_correctness_fn_def,
-)
-from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
-    factuality_fn_def,
-)
-
-
-BRAINTRUST_FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
-
-
-class BraintrustScoringFn(BaseScoringFn):
-    """
-    Test whether an output is factual, compared to an original (`expected`) value.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.braintrust_evaluators = {
-            "braintrust::factuality": Factuality(),
-            "braintrust::answer-correctness": AnswerCorrectness(),
-        }
-        self.supported_fn_defs_registry = {
-            factuality_fn_def.identifier: factuality_fn_def,
-            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
-        }
-
-    async def score_row(
-        self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-    ) -> ScoringResultRow:
-        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
-        expected_answer = input_row["expected_answer"]
-        generated_answer = input_row["generated_answer"]
-        input_query = input_row["input_query"]
-        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
-
-        result = evaluator(generated_answer, expected_answer, input=input_query)
-        score = result.score
-        return {"score": score, "metadata": result.metadata}
-
-    async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
-    ) -> Dict[str, Any]:
-        return aggregate_average(scoring_results)
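With the BraintrustScoringFn class removed, its per-row logic lives on BraintrustScoringImpl.score_row() and its aggregate() step is replaced by a direct call to aggregate_average(). Below is a rough usage sketch of the resulting surface; it assumes an already-constructed and initialized impl (the stack resolver normally wires up the config and the datasetio/datasets APIs) and the ScoreResponse/ScoringResult fields referenced in the diff, so the driver function and sample rows are illustrative only.

    # Hypothetical driver; `impl` is an already-wired BraintrustScoringImpl.
    async def demo(impl: BraintrustScoringImpl) -> None:
        rows = [
            {
                "input_query": "Who wrote Hamlet?",
                "generated_answer": "Hamlet was written by William Shakespeare.",
                "expected_answer": "William Shakespeare",
            }
        ]
        response = await impl.score(rows, ["braintrust::factuality"])
        result = response.results["braintrust::factuality"]
        # Per-row outputs from score_row() and the aggregate_average() summary.
        print(result.score_rows, result.aggregated_results)

    # Run with: asyncio.run(demo(impl))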