remove braintrust scoring_fn

This commit is contained in:
Xi Yan 2024-10-28 12:41:06 -07:00
parent b5ed80ac15
commit f5d41b582e
2 changed files with 55 additions and 92 deletions

View file

@ -11,12 +11,23 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
answer_correctness_fn_def,
)
from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
factuality_fn_def,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_average,
)
from .config import BraintrustScoringConfig
from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(
@ -28,26 +39,31 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.braintrust_scoring_fn_impl = None
self.supported_fn_ids = {}
async def initialize(self) -> None:
self.braintrust_scoring_fn_impl = BraintrustScoringFn()
self.supported_fn_ids = {
x.identifier
for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
self.braintrust_evaluators = {
"braintrust::factuality": Factuality(),
"braintrust::answer-correctness": AnswerCorrectness(),
}
self.supported_fn_defs_registry = {
factuality_fn_def.identifier: factuality_fn_def,
answer_correctness_fn_def.identifier: answer_correctness_fn_def,
}
# self.braintrust_scoring_fn_impl = None
# self.supported_fn_ids = {}
async def initialize(self) -> None: ...
# self.braintrust_scoring_fn_impl = BraintrustScoringFn()
# self.supported_fn_ids = {
# x.identifier
# for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
# }
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFnDef]:
assert (
self.braintrust_scoring_fn_impl is not None
), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
scoring_fn_defs_list = (
self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
)
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
for f in scoring_fn_defs_list:
assert f.identifier.startswith(
"braintrust"
@ -100,22 +116,37 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
results=res.results,
)
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier]
result = evaluator(generated_answer, expected_answer, input=input_query)
score = result.score
return {"score": score, "metadata": result.metadata}
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
assert (
self.braintrust_scoring_fn_impl is not None
), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
# assert (
# self.braintrust_scoring_fn_impl is not None
# ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_ids:
if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
score_results = await self.braintrust_scoring_fn_impl.score(
input_rows, scoring_fn_id
)
agg_results = await self.braintrust_scoring_fn_impl.aggregate(score_results)
score_results = [
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
agg_results = aggregate_average(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -1,68 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from typing import Any, Dict, List, Optional
# TODO: move the common base out from meta-reference into common
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
BaseScoringFn,
)
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
aggregate_average,
)
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
answer_correctness_fn_def,
)
from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
factuality_fn_def,
)
BRAINTRUST_FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
class BraintrustScoringFn(BaseScoringFn):
"""
Test whether an output is factual, compared to an original (`expected`) value.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.braintrust_evaluators = {
"braintrust::factuality": Factuality(),
"braintrust::answer-correctness": AnswerCorrectness(),
}
self.supported_fn_defs_registry = {
factuality_fn_def.identifier: factuality_fn_def,
answer_correctness_fn_def.identifier: answer_correctness_fn_def,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier]
result = evaluator(generated_answer, expected_answer, input=input_query)
score = result.score
return {"score": score, "metadata": result.metadata}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_average(scoring_results)