Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 07:39:38 +00:00)

Commit f5d41b582e (parent: b5ed80ac15)

    remove braintrust scoring_fn

2 changed files with 55 additions and 92 deletions
File 1 of 2 (modified): the Braintrust scoring provider module under llama_stack/providers/impls/braintrust/scoring/

@@ -11,12 +11,23 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 
+# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
+from autoevals.llm import Factuality
+from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
+    answer_correctness_fn_def,
+)
+from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
+    factuality_fn_def,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+)
 
 from .config import BraintrustScoringConfig
-from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
 
 
 class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
@@ -28,26 +39,31 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
-        self.braintrust_scoring_fn_impl = None
-        self.supported_fn_ids = {}
+        self.braintrust_evaluators = {
+            "braintrust::factuality": Factuality(),
+            "braintrust::answer-correctness": AnswerCorrectness(),
+        }
+        self.supported_fn_defs_registry = {
+            factuality_fn_def.identifier: factuality_fn_def,
+            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
+        }
+
+        # self.braintrust_scoring_fn_impl = None
+        # self.supported_fn_ids = {}
 
-    async def initialize(self) -> None:
-        self.braintrust_scoring_fn_impl = BraintrustScoringFn()
-        self.supported_fn_ids = {
-            x.identifier
-            for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
-        }
+    async def initialize(self) -> None: ...
+
+    # self.braintrust_scoring_fn_impl = BraintrustScoringFn()
+    # self.supported_fn_ids = {
+    #     x.identifier
+    #     for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
+    # }
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        assert (
-            self.braintrust_scoring_fn_impl is not None
-        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
-        scoring_fn_defs_list = (
-            self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
-        )
+        scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
 
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith(
                 "braintrust"
@@ -100,22 +116,37 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             results=res.results,
         )
 
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
+        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        input_query = input_row["input_query"]
+        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
+
+        result = evaluator(generated_answer, expected_answer, input=input_query)
+        score = result.score
+        return {"score": score, "metadata": result.metadata}
+
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
-        assert (
-            self.braintrust_scoring_fn_impl is not None
-        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
+        # assert (
+        #     self.braintrust_scoring_fn_impl is not None
+        # ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
 
         res = {}
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in self.supported_fn_ids:
+            if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
 
-            score_results = await self.braintrust_scoring_fn_impl.score(
-                input_rows, scoring_fn_id
-            )
-            agg_results = await self.braintrust_scoring_fn_impl.aggregate(score_results)
+            score_results = [
+                await self.score_row(input_row, scoring_fn_id)
+                for input_row in input_rows
+            ]
+
+            agg_results = aggregate_average(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
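Net effect of the file above: the provider now calls the autoevals evaluators directly instead of routing through a BraintrustScoringFn wrapper. What follows is a minimal standalone sketch of that flow; the sample row is invented, and aggregate_mean is a stand-in for the shared aggregate_average helper, whose exact output shape is not visible in this diff. Factuality and AnswerCorrectness are LLM-based scorers, so running this assumes the autoevals package is installed and an OpenAI API key is configured.

from typing import Any, Dict, List

from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness

# Same registry the new __init__ builds, keyed by scoring function id.
EVALUATORS = {
    "braintrust::factuality": Factuality(),
    "braintrust::answer-correctness": AnswerCorrectness(),
}


def score_row(input_row: Dict[str, Any], scoring_fn_id: str) -> Dict[str, Any]:
    # Mirrors BraintrustScoringImpl.score_row: pull the three expected
    # columns out of the row and invoke the evaluator directly.
    evaluator = EVALUATORS[scoring_fn_id]
    result = evaluator(
        input_row["generated_answer"],
        input_row["expected_answer"],
        input=input_row["input_query"],
    )
    return {"score": result.score, "metadata": result.metadata}


def aggregate_mean(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Stand-in for aggregate_average: assumed here to report the mean score.
    return {"average": sum(r["score"] for r in rows) / len(rows)}


if __name__ == "__main__":
    rows = [
        {  # invented example row; real rows come in via the datasetio API
            "input_query": "What is the capital of France?",
            "generated_answer": "Paris is the capital of France.",
            "expected_answer": "Paris",
        }
    ]
    results = [score_row(r, "braintrust::factuality") for r in rows]
    print(results)
    print(aggregate_mean(results))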
File 2 of 2 (deleted): llama_stack/providers/impls/braintrust/scoring/scoring_fn/braintrust_scoring_fn.py

@@ -1,68 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from pathlib import Path
-
-from typing import Any, Dict, List, Optional
-
-# TODO: move the common base out from meta-reference into common
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
-    BaseScoringFn,
-)
-from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
-    aggregate_average,
-)
-from llama_stack.apis.scoring import *  # noqa: F403
-from llama_stack.apis.scoring_functions import *  # noqa: F403
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness
-from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.answer_correctness import (
-    answer_correctness_fn_def,
-)
-from llama_stack.providers.impls.braintrust.scoring.scoring_fn.fn_defs.factuality import (
-    factuality_fn_def,
-)
-
-
-BRAINTRUST_FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
-
-
-class BraintrustScoringFn(BaseScoringFn):
-    """
-    Test whether an output is factual, compared to an original (`expected`) value.
-    """
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.braintrust_evaluators = {
-            "braintrust::factuality": Factuality(),
-            "braintrust::answer-correctness": AnswerCorrectness(),
-        }
-        self.supported_fn_defs_registry = {
-            factuality_fn_def.identifier: factuality_fn_def,
-            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
-        }
-
-    async def score_row(
-        self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-    ) -> ScoringResultRow:
-        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
-        expected_answer = input_row["expected_answer"]
-        generated_answer = input_row["generated_answer"]
-        input_query = input_row["input_query"]
-        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
-
-        result = evaluator(generated_answer, expected_answer, input=input_query)
-        score = result.score
-        return {"score": score, "metadata": result.metadata}
-
-    async def aggregate(
-        self, scoring_results: List[ScoringResultRow]
-    ) -> Dict[str, Any]:
-        return aggregate_average(scoring_results)
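Taken together, the commit deletes the BraintrustScoringFn wrapper outright and folds its evaluator registry, per-row scoring, and aggregation into BraintrustScoringImpl itself, so initialize() becomes a no-op and aggregation reuses the shared aggregate_average helper from the meta-reference provider. A caller-side sketch of the reworked score API follows; the response shape (a results mapping keyed by scoring function id, each entry carrying score_rows and aggregated_results) is inferred from the hunks above, and the helper assumes an already-constructed provider instance, since construction is wired up through the llama-stack provider registry rather than shown in this diff.

async def demo_score(impl) -> None:
    # `impl` is an already-constructed BraintrustScoringImpl.
    response = await impl.score(
        input_rows=[
            {
                "input_query": "What is the capital of France?",
                "generated_answer": "Paris",
                "expected_answer": "Paris",
            }
        ],
        scoring_functions=["braintrust::answer-correctness"],
    )
    result = response.results["braintrust::answer-correctness"]
    print(result.score_rows)          # per-row {"score": ..., "metadata": ...}
    print(result.aggregated_results)  # output of aggregate_average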