Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 15:49:40 +00:00)

commit 38186f7903 (parent 68346fac39)

    braintrust provider

6 changed files with 116 additions and 49 deletions
@@ -11,12 +11,12 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
-from autoevals.llm import Factuality
-from autoevals.ragas import AnswerCorrectness
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate

 from .config import BraintrustScoringConfig
+
+from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn


 class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
@@ -28,36 +28,29 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
-        self.scoring_fn_id_impls = {}
+        self.braintrust_scoring_fn_impl = None
+        self.supported_fn_ids = {}

     async def initialize(self) -> None:
-        self.scoring_fn_id_impls = {
-            "braintrust::factuality": Factuality(),
-            "braintrust::answer-correctness": AnswerCorrectness(),
+        self.braintrust_scoring_fn_impl = BraintrustScoringFn()
+        await self.braintrust_scoring_fn_impl.initialize()
+        self.supported_fn_ids = {
+            x.identifier
+            for x in self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()
         }

     async def shutdown(self) -> None: ...

     async def list_scoring_functions(self) -> List[ScoringFnDef]:
-        return [
-            ScoringFnDef(
-                identifier="braintrust::factuality",
-                description="Test whether an output is factual, compared to an original (`expected`) value.",
-                parameters=[],
-                return_type=NumberType(),
-            ),
-            ScoringFnDef(
-                identifier="braintrust::answer-correctness",
-                description="Test whether an output is factual, compared to an original (`expected`) value.",
-                parameters=[],
-                return_type=NumberType(),
-            ),
-        ]
+        assert (
+            self.braintrust_scoring_fn_impl is not None
+        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
+        return self.braintrust_scoring_fn_impl.get_supported_scoring_fn_defs()

     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        # self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        # self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
-        return None
+        raise NotImplementedError(
+            "Registering scoring function not allowed for braintrust provider"
+        )

     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -82,19 +75,18 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         scoring_functions: List[str],
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        print("score_batch")
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
-        # all_rows = await self.datasetio_api.get_rows_paginated(
-        #     dataset_id=dataset_id,
-        #     rows_in_page=-1,
-        # )
-        # res = await self.score(
-        #     input_rows=all_rows.rows, scoring_functions=scoring_functions
-        # )
-        # if save_results_dataset:
-        #     # TODO: persist and register dataset on to server for reading
-        #     # self.datasets_api.register_dataset()
-        #     raise NotImplementedError("Save results dataset not implemented yet")
+        all_rows = await self.datasetio_api.get_rows_paginated(
+            dataset_id=dataset_id,
+            rows_in_page=-1,
+        )
+        res = await self.score(
+            input_rows=all_rows.rows, scoring_functions=scoring_functions
+        )
+        if save_results_dataset:
+            # TODO: persist and register dataset on to server for reading
+            # self.datasets_api.register_dataset()
+            raise NotImplementedError("Save results dataset not implemented yet")

         return ScoreBatchResponse(
             results=res.results,
@@ -103,19 +95,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
     ) -> ScoreResponse:
+        assert (
+            self.braintrust_scoring_fn_impl is not None
+        ), "braintrust_scoring_fn_impl is not initialized, need to call initialize for provider. "
+
         res = {}
-        print("score")
         for scoring_fn_id in scoring_functions:
-            if scoring_fn_id not in self.scoring_fn_id_impls:
+            if scoring_fn_id not in self.supported_fn_ids:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")

+            # scoring_impl = self.scoring_fn_id_impls[scoring_fn_id]
             # scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            # score_results = await scoring_fn.score(input_rows, scoring_fn_id)
-            # agg_results = await scoring_fn.aggregate(score_results)
-            # res[scoring_fn_id] = ScoringResult(
-            #     score_rows=score_results,
-            #     aggregated_results=agg_results,
-            # )
+            score_results = await self.braintrust_scoring_fn_impl.score(
+                input_rows, scoring_fn_id
+            )
+            agg_results = await self.braintrust_scoring_fn_impl.aggregate(score_results)
+            res[scoring_fn_id] = ScoringResult(
+                score_rows=score_results,
+                aggregated_results=agg_results,
+            )

         return ScoreResponse(
             results=res,
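Taken together, the reworked provider funnels every request through a single BraintrustScoringFn instance instead of a per-function registry. A rough usage sketch of the new flow follows; the import paths and the empty config are assumptions (they are not shown in this diff), only the method names, keyword arguments, and row keys come from the commit, and actually running it needs API credentials for the LLM-based autoevals scorers.

# Hypothetical sketch, not part of this commit. Import paths and the empty
# config are assumed; method names and the input row shape come from the diff.
import asyncio

from llama_stack.providers.impls.braintrust.scoring.braintrust import (  # assumed path
    BraintrustScoringImpl,
)
from llama_stack.providers.impls.braintrust.scoring.config import (  # assumed path
    BraintrustScoringConfig,
)


async def main() -> None:
    impl = BraintrustScoringImpl(
        config=BraintrustScoringConfig(),  # assumed to construct with defaults
        datasetio_api=None,  # only score_batch() touches the datasetio/datasets APIs
        datasets_api=None,
    )
    await impl.initialize()  # builds BraintrustScoringFn and supported_fn_ids

    response = await impl.score(
        input_rows=[
            {
                "input_query": "What is the capital of France?",
                "generated_answer": "Paris",
                "expected_answer": "Paris",
            }
        ],
        scoring_functions=["braintrust::factuality"],
    )
    print(response.results["braintrust::factuality"].aggregated_results)


asyncio.run(main())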
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from typing import Any, Dict, List, Optional
+
+# TODO: move the common base out from meta-reference into common
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+)
+from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from autoevals.llm import Factuality
+from autoevals.ragas import AnswerCorrectness
+
+
+BRAINTRUST_FN_DEFS_PATH = Path(__file__).parent / "fn_defs"
+
+
+class BraintrustScoringFn(BaseScoringFn):
+    """
+    Test whether an output is factual, compared to an original (`expected`) value.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.braintrust_evaluators = {
+            "braintrust::factuality": Factuality(),
+            "braintrust::answer-correctness": AnswerCorrectness(),
+        }
+        self.defs_paths = [
+            str(x) for x in sorted(BRAINTRUST_FN_DEFS_PATH.glob("*.json"))
+        ]
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
+        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        input_query = input_row["input_query"]
+        evaluator = self.braintrust_evaluators[scoring_fn_identifier]
+
+        result = evaluator(generated_answer, expected_answer, input=input_query)
+        score = result.score
+        return {"score": score, "metadata": result.metadata}
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_average(scoring_results)
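The score_row path above is a thin wrapper around the autoevals evaluators. Below is a minimal sketch of what happens per row, assuming an OpenAI-compatible API key is configured (both scorers call an LLM under the hood); the sample row is made up. aggregate() then reduces the per-row scores with the meta-reference aggregate_average helper, so aggregated_results is presumably the mean score over the batch.

# Minimal sketch of the per-row evaluation that score_row performs. The sample
# row is invented; an API key for the underlying LLM is assumed to be configured.
from autoevals.llm import Factuality

row = {
    "input_query": "What is the capital of France?",
    "generated_answer": "Paris is the capital of France.",
    "expected_answer": "Paris",
}

evaluator = Factuality()
result = evaluator(
    row["generated_answer"], row["expected_answer"], input=row["input_query"]
)

print(result.score)     # numeric score for the row
print(result.metadata)  # extra details from the evaluator (e.g. its rationale)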
@@ -0,0 +1,10 @@
+{
+    "identifier": "braintrust::answer-correctness",
+    "description": "Test whether an output is factual, compared to an original (`expected`) value. One of the Braintrust LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
+    "metadata": {},
+    "parameters": [],
+    "return_type": {
+        "type": "number"
+    },
+    "context": null
+}
@@ -1,6 +1,6 @@
 {
     "identifier": "braintrust::factuality",
-    "description": "Test whether an output is factual, compared to an original (`expected`) value. ",
+    "description": "Test whether an output is factual, compared to an original (`expected`) value. One of the Braintrust LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
     "metadata": {},
     "parameters": [],
     "return_type": {
@@ -82,7 +82,7 @@ async def register_dataset(

     dataset = DatasetDefWithProvider(
         identifier=dataset_id,
-        provider_id=os.environ["PROVIDER_ID"] or os.environ["DATASETIO_PROVIDER_ID"],
+        provider_id=os.environ["DATASETIO_PROVIDER_ID"] or os.environ["PROVIDER_ID"],
         url=URL(
             uri=test_url,
         ),
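One note on this fixture change: os.environ[...] raises KeyError when a variable is missing, so the reordered `or` only falls back to PROVIDER_ID when DATASETIO_PROVIDER_ID is set but empty. A more forgiving variant (an alternative sketch, not what the commit does) would use .get():

# Alternative sketch, not what the commit does: tolerate either variable being unset.
import os

provider_id = os.environ.get("DATASETIO_PROVIDER_ID") or os.environ.get("PROVIDER_ID")
if provider_id is None:
    raise RuntimeError("Set DATASETIO_PROVIDER_ID or PROVIDER_ID before running the datasetio tests")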
@@ -141,16 +141,14 @@ async def test_scoring_score(scoring_settings, provider_scoring_functions):
     function_ids = [f.identifier for f in scoring_functions]
     provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
     provider_type = provider.__provider_spec__.provider_type
-    if provider_type not in ("meta-reference"):
-        pytest.skip(
-            "Other scoring providers don't support registering scoring functions."
-        )

     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
         scoring_functions=list(provider_scoring_functions[provider_type]),
     )

+    print("RESPONSE", response)
+
     assert len(response.results) == len(provider_scoring_functions[provider_type])
     for x in provider_scoring_functions[provider_type]:
         assert x in response.results