Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-29 19:34:19 +00:00)
[Evals API][7/n] braintrust scoring provider (#333)
* wip scoring refactor
* llm as judge, move folders
* test full generation + eval
* extract score regex to llm context
* remove prints, cleanup braintrust in this branch
* braintrust skeleton
* datasetio test fix
* braintrust provider
* remove prints
* dependencies
* change json -> class
* json -> class
* remove initialize
* address nits
* check identifier prefix
* braintrust scoring identifier check, rebase
* update MANIFEST
* manifest
* remove braintrust scoring_fn
* remove comments
* tests
* imports fix
Parent: ae671eaf7a
Commit: ed833bb758
11 changed files with 274 additions and 15 deletions
llama_stack/providers/impls/braintrust/scoring/__init__.py (new file, 21 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict

from llama_stack.distribution.datatypes import Api, ProviderSpec

from .config import BraintrustScoringConfig


async def get_provider_impl(
    config: BraintrustScoringConfig,
    deps: Dict[Api, ProviderSpec],
):
    from .braintrust import BraintrustScoringImpl

    impl = BraintrustScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
    await impl.initialize()
    return impl
llama_stack/providers/impls/braintrust/scoring/braintrust.py (new file, 140 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.apis.datasetio import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403

# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
    aggregate_average,
)

from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def


class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
    def __init__(
        self,
        config: BraintrustScoringConfig,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api

        self.braintrust_evaluators = {
            "braintrust::factuality": Factuality(),
            "braintrust::answer-correctness": AnswerCorrectness(),
        }
        self.supported_fn_defs_registry = {
            factuality_fn_def.identifier: factuality_fn_def,
            answer_correctness_fn_def.identifier: answer_correctness_fn_def,
        }

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def list_scoring_functions(self) -> List[ScoringFnDef]:
        scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
        for f in scoring_fn_defs_list:
            assert f.identifier.startswith(
                "braintrust"
            ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'!"

        return scoring_fn_defs_list

    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
        raise NotImplementedError(
            "Registering scoring function not allowed for braintrust provider"
        )

    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
            raise ValueError(
                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
            )

        for required_column in ["generated_answer", "expected_answer", "input_query"]:
            if required_column not in dataset_def.dataset_schema:
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column."
                )
            if dataset_def.dataset_schema[required_column].type != "string":
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
                )

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: List[str],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=-1,
        )
        res = await self.score(
            input_rows=all_rows.rows, scoring_functions=scoring_functions
        )
        if save_results_dataset:
            # TODO: persist and register dataset on to server for reading
            # self.datasets_api.register_dataset()
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res.results,
        )

    async def score_row(
        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
    ) -> ScoringResultRow:
        assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
        expected_answer = input_row["expected_answer"]
        generated_answer = input_row["generated_answer"]
        input_query = input_row["input_query"]
        evaluator = self.braintrust_evaluators[scoring_fn_identifier]

        result = evaluator(generated_answer, expected_answer, input=input_query)
        score = result.score
        return {"score": score, "metadata": result.metadata}

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
        res = {}
        for scoring_fn_id in scoring_functions:
            if scoring_fn_id not in self.supported_fn_defs_registry:
                raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")

            score_results = [
                await self.score_row(input_row, scoring_fn_id)
                for input_row in input_rows
            ]

            agg_results = aggregate_average(score_results)
            res[scoring_fn_id] = ScoringResult(
                score_rows=score_results,
                aggregated_results=agg_results,
            )

        return ScoreResponse(
            results=res,
        )
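For orientation, here is a minimal sketch of driving this implementation directly through its score() method, which only reads the rows it is given and never touches the dataset APIs. It is not part of the commit; it assumes autoevals and its openai dependency are installed and that OPENAI_API_KEY is set, since Factuality and AnswerCorrectness grade each row by calling an LLM.

# Hypothetical driver script (not part of this commit) for local smoke-testing.
import asyncio

from llama_stack.providers.impls.braintrust.scoring import BraintrustScoringConfig
from llama_stack.providers.impls.braintrust.scoring.braintrust import BraintrustScoringImpl


async def main():
    # score() never calls the dataset APIs, so stubbing them with None is enough here;
    # score_batch() would need real datasetio/datasets implementations.
    impl = BraintrustScoringImpl(BraintrustScoringConfig(), None, None)
    await impl.initialize()

    rows = [
        {
            "input_query": "What is the capital of France?",
            "generated_answer": "Paris is the capital of France.",
            "expected_answer": "Paris",
        }
    ]
    response = await impl.score(
        input_rows=rows,
        scoring_functions=["braintrust::factuality", "braintrust::answer-correctness"],
    )
    for fn_id, result in response.results.items():
        # each entry carries per-row scores plus the aggregate_average() rollup
        print(fn_id, result.aggregated_results)


asyncio.run(main())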
llama_stack/providers/impls/braintrust/scoring/config.py (new file, 9 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import *  # noqa: F401, F403


class BraintrustScoringConfig(BaseModel): ...
Two additional new files, 5 lines each (paths not shown in this view); both contain only the license header:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/answer_correctness.py (new file, 16 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef


answer_correctness_fn_def = ScoringFnDef(
    identifier="braintrust::answer-correctness",
    description="Scores the correctness of the generated answer against the expected answer. One of the Braintrust LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/ragas.py",
    parameters=[],
    return_type=NumberType(),
)
llama_stack/providers/impls/braintrust/scoring/scoring_fn/fn_defs/factuality.py (new file, 16 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFnDef


factuality_fn_def = ScoringFnDef(
    identifier="braintrust::factuality",
    description="Test whether an output is factual, compared to an original (`expected`) value. One of the Braintrust LLM-based scorers: https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
    parameters=[],
    return_type=NumberType(),
)
Provider registry for the scoring API:

@@ -23,4 +23,15 @@ def available_providers() -> List[ProviderSpec]:
                 Api.inference,
             ],
         ),
+        InlineProviderSpec(
+            api=Api.scoring,
+            provider_type="braintrust",
+            pip_packages=["autoevals", "openai"],
+            module="llama_stack.providers.impls.braintrust.scoring",
+            config_class="llama_stack.providers.impls.braintrust.scoring.BraintrustScoringConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+            ],
+        ),
     ]
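As a quick sanity check of the new spec, one could enumerate the scoring registry and pull out the braintrust entry. A sketch, assuming the hunk above lives in llama_stack.providers.registry.scoring (the file path is not shown in this view, only the available_providers() context):

# Sketch only: the registry module path is inferred, not shown in the diff above.
from llama_stack.providers.registry.scoring import available_providers

braintrust_spec = next(
    spec for spec in available_providers() if spec.provider_type == "braintrust"
)
print(braintrust_spec.pip_packages)      # expected: ["autoevals", "openai"]
print(braintrust_spec.api_dependencies)  # expected: [Api.datasetio, Api.datasets]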
Dataset registration helper in the datasetio tests:

@@ -82,7 +82,8 @@ async def register_dataset(
 
     dataset = DatasetDefWithProvider(
         identifier=dataset_id,
-        provider_id=os.environ["PROVIDER_ID"],
+        provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None)
+        or os.environ["PROVIDER_ID"],
         url=URL(
             uri=test_url,
         ),
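The intent of the change above: DATASETIO_PROVIDER_ID, when set, overrides the generic PROVIDER_ID so the datasetio provider can differ from the scoring provider under test. A small illustration of the fallback (the values are hypothetical):

import os

# Hypothetical values, for illustration only.
os.environ["PROVIDER_ID"] = "test-braintrust"
os.environ["DATASETIO_PROVIDER_ID"] = "test-meta"

# Same fallback expression as in the hunk above.
provider_id = os.environ.get("DATASETIO_PROVIDER_ID", None) or os.environ["PROVIDER_ID"]
assert provider_id == "test-meta"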
Scoring test run config:

@@ -7,6 +7,9 @@ providers:
   - provider_id: test-meta
     provider_type: meta-reference
     config: {}
+  - provider_id: test-braintrust
+    provider_type: braintrust
+    config: {}
   inference:
   - provider_id: tgi0
     provider_type: remote::tgi
Scoring provider tests:

@@ -43,16 +43,35 @@ async def scoring_settings():
     }
 
 
+@pytest_asyncio.fixture(scope="session")
+async def provider_scoring_functions():
+    return {
+        "meta-reference": {
+            "meta-reference::equality",
+            "meta-reference::subset_of",
+            "meta-reference::llm_as_judge_8b_correctness",
+        },
+        "braintrust": {
+            "braintrust::factuality",
+            "braintrust::answer-correctness",
+        },
+    }
+
+
 @pytest.mark.asyncio
-async def test_scoring_functions_list(scoring_settings):
+async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
+    scoring_impl = scoring_settings["scoring_impl"]
     scoring_functions_impl = scoring_settings["scoring_functions_impl"]
     scoring_functions = await scoring_functions_impl.list_scoring_functions()
     assert isinstance(scoring_functions, list)
     assert len(scoring_functions) > 0
     function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::equality" in function_ids
-    assert "meta-reference::subset_of" in function_ids
-    assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
+    # get current provider_type we're testing
+    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
+    provider_type = provider.__provider_spec__.provider_type
+
+    for x in provider_scoring_functions[provider_type]:
+        assert x in function_ids
 
 
 @pytest.mark.asyncio
@@ -60,6 +79,17 @@ async def test_scoring_functions_register(scoring_settings):
     scoring_impl = scoring_settings["scoring_impl"]
     scoring_functions_impl = scoring_settings["scoring_functions_impl"]
     datasets_impl = scoring_settings["datasets_impl"]
+
+    # get current provider_type we're testing
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    function_ids = [f.identifier for f in scoring_functions]
+    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
+    provider_type = provider.__provider_spec__.provider_type
+    if provider_type not in ("meta-reference",):
+        pytest.skip(
+            "Other scoring providers don't support registering scoring functions."
+        )
+
     test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
     # register the scoring function
     await scoring_functions_impl.register_scoring_function(
@@ -97,24 +127,26 @@
 
 
 @pytest.mark.asyncio
-async def test_scoring_score(scoring_settings):
+async def test_scoring_score(scoring_settings, provider_scoring_functions):
     scoring_impl = scoring_settings["scoring_impl"]
     datasets_impl = scoring_settings["datasets_impl"]
+    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
     await register_dataset(datasets_impl)
 
     response = await datasets_impl.list_datasets()
     assert len(response) == 1
 
+    # get current provider_type we're testing
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    function_ids = [f.identifier for f in scoring_functions]
+    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
+    provider_type = provider.__provider_spec__.provider_type
+
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::equality",
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        ],
+        scoring_functions=list(provider_scoring_functions[provider_type]),
     )
 
-    assert len(response.results) == 3
-    assert "meta-reference::equality" in response.results
-    assert "meta-reference::subset_of" in response.results
-    assert "meta-reference::llm_as_judge_8b_correctness" in response.results
+    assert len(response.results) == len(provider_scoring_functions[provider_type])
+    for x in provider_scoring_functions[provider_type]:
+        assert x in response.results