Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
commit f5d41b582e (parent b5ed80ac15)

    remove braintrust scoring_fn

2 changed files with 55 additions and 92 deletions
@@ -11,12 +11,23 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 
+# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
+from autoevals.llm import Factuality
+from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
+    answer_correctness_fn_def,
+)
+from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
+    factuality_fn_def,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+)
 
 from .config import BraintrustScoringConfig
 
-from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
 
 
 class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
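Aside (not part of the diff): with Factuality and AnswerCorrectness imported directly, the provider can call autoevals evaluators as plain callables, which is the pattern inlined into score_row below. A minimal sketch of that call with hypothetical strings, assuming an OPENAI_API_KEY is configured since both evaluators are LLM-graded:

# Sketch only: direct autoevals usage that this commit inlines into the provider.
from autoevals.llm import Factuality

evaluator = Factuality()
result = evaluator(
    "Paris is the capital city of France.",  # generated answer (output)
    "Paris",                                 # expected answer
    input="What is the capital of France?",  # original query
)
print(result.score)     # float in [0, 1]
print(result.metadata)  # typically includes the grader's rationale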
@@ -28,26 +39,31 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
-        self.braintrust_scoring_fn_impl = None
-        self.supported_fn_ids = {}
-
-    async def initialize(self) -> None:
-        self.braintrust_scoring_fn_impl = BraintrustScoringFn()
-        self.supported_fn_ids = {
-            x.identifier
-            for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
+        self.braintrust_evaluators = {
+            "braintrust::factuality": Factuality(),
+            "braintrust::answer-correctness": AnswerCorrectness(),
+        }
+        self.supported_fn_defs_registry = {
+            factuality_fn_def.identifier: factuality_fn_def,
+            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
         }
+
+        # self.braintrust_scoring_fn_impl = None
+        # self.supported_fn_ids = {}
+
+    async def initialize(self) -> None: ...
+
+        # self.braintrust_scoring_fn_impl = BraintrustScoringFn()
+        # self.supported_fn_ids = {
+        #     x.identifier
+        #     for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
+        # }
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        assert (
-            self.braintrust_scoring_fn_impl is not None
-        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
-        scoring_fn_defs_list = (
-            self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
-        )
+        scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith(
                 "braintrust"
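The deleted class's async aggregate step is replaced by the synchronous aggregate_average helper imported above from the meta-reference common module. Its implementation is not shown in this diff; a hedged sketch of what it plausibly computes, with the "average" key and the None-skipping inferred from how the helper is consumed in the next hunk:

# Assumed shape of aggregate_average; not part of this diff.
from typing import Any, Dict, List

def aggregate_average(scoring_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Average the per-row "score" values, skipping rows without a score.
    scored = [r["score"] for r in scoring_results if r.get("score") is not None]
    return {"average": sum(scored) / len(scored)}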
@@ -100,22 +116,37 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             results=res.results,
         )
 
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
+        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        input_query = input_row["input_query"]
+        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
+
+        result = evaluator(generated_answer, expected_answer, input=input_query)
+        score = result.score
+        return {"score": score, "metadata": result.metadata}
+
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
-        assert (
-            self.braintrust_scoring_fn_impl is not None
-        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
+        # assert (
+        #     self.braintrust_scoring_fn_impl is not None
+        # ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
 
         res = {}
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in self.supported_fn_ids:
+            if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
 
-            score_results = await self.braintrust_scoring_fn_impl.score(
-                input_rows, scoring_fn_id
-            )
-            agg_results = await self.braintrust_scoring_fn_impl.aggregate(score_results)
+            score_results = [
+                await self.score_row(input_row, scoring_fn_id)
+                for input_row in input_rows
+            ]
+
+            agg_results = aggregate_average(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
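Taken together, a hypothetical end-to-end driver for the reworked provider. The keyword names follow the assignments in __init__, but the BraintrustScoringConfig construction, the None stubs for the datasetio/datasets APIs, and the ScoreResponse.results access are assumptions for illustration:

# Hypothetical usage sketch; assumes the provider module above is importable.
import asyncio

async def main() -> None:
    impl = BraintrustScoringImpl(
        config=BraintrustScoringConfig(),
        datasetio_api=None,  # unused when scoring explicit rows
        datasets_api=None,
    )
    await impl.initialize()  # now a no-op; evaluators are built in __init__
    response = await impl.score(
        input_rows=[
            {
                "input_query": "What is the capital of France?",
                "generated_answer": "Paris is the capital of France.",
                "expected_answer": "Paris",
            }
        ],
        scoring_functions=["braintrust::factuality"],
    )
    print(response.results["braintrust::factuality"].aggregated_results)

asyncio.run(main())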
@@ -1,68 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from pathlib import Path
-
-from typing import Any, Dict, List, Optional
-
-# TODO: move the common base out from meta-reference into common
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
-    aggregate_average,
-)
-from llama_stack.apis.scoring import *  # noqa: F403
-from llama_stack.apis.scoring_functions import *  # noqa: F403
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness
-from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
-    answer_correctness_fn_def,
-)
-from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
-    factuality_fn_def,
-)
-
-
-BRAINTRUST_FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
-
-
-class BraintrustScoringFn(BaseScoringFn):
-    """
-    Test whether an output is factual, compared to an original (`expected`) value.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.braintrust_evaluators = {
-            "braintrust::factuality": Factuality(),
-            "braintrust::answer-correctness": AnswerCorrectness(),
-        }
-        self.supported_fn_defs_registry = {
-            factuality_fn_def.identifier: factuality_fn_def,
-            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
-        }
-
-    async def score_row(
-        self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-    ) -> ScoringResultRow:
-        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
-        expected_answer = input_row["expected_answer"]
-        generated_answer = input_row["generated_answer"]
-        input_query = input_row["input_query"]
-        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
-
-        result = evaluator(generated_answer, expected_answer, input=input_query)
-        score = result.score
-        return {"score": score, "metadata": result.metadata}
-
-    async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
-    ) -> Dict[str, Any]:
-        return aggregate_average(scoring_results)
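The fn_defs modules that both files reference (factuality_fn_def, answer_correctness_fn_def) survive this commit but are not shown in the diff. A plausible minimal shape for fn_defs/factuality.py; every field value beyond the identifier is an assumption:

# Hypothetical sketch of fn_defs/factuality.py; not shown in this diff.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef

factuality_fn_def = ScoringFnDef(
    identifier="braintrust::factuality",
    description="Test whether an output is factual, compared to an original "
    "(`expected`) value.",
    parameters=[],
    return_type=NumberType(),
)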