mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 20:18:52 +00:00
[Evals API][7/n] braintrust scoring provider (#333)
* wip scoring refactor * llm as judge, move folders * test full generation + eval * extract score regex to llm context * remove prints, cleanup braintrust in this branch * braintrust skeleton * datasetio test fix * braintrust provider * remove prints * dependencies * change json -> class * json -> class * remove initialize * address nits * check identifier prefix * braintrust scoring identifier check, rebase * udpate MANIFEST * manifest * remove braintrust scoring_fn * remove comments * tests * imports fix
This commit is contained in:
parent
ae671eaf7a
commit
ed833bb758
11 changed files with 274 additions and 15 deletions
|
@ -7,6 +7,9 @@ providers:
|
|||
- provider_id: test-meta
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
- provider_id: test-braintrust
|
||||
provider_type: braintrust
|
||||
config: {}
|
||||
inference:
|
||||
- provider_id: tgi0
|
||||
provider_type: remote::tgi
|
||||
|
|
|
@ -43,16 +43,35 @@ async def scoring_settings():
|
|||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def provider_scoring_functions():
|
||||
return {
|
||||
"meta-reference": {
|
||||
"meta-reference::equality",
|
||||
"meta-reference::subset_of",
|
||||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
},
|
||||
"braintrust": {
|
||||
"braintrust::factuality",
|
||||
"braintrust::answer-correctness",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_list(scoring_settings):
|
||||
async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
|
||||
scoring_impl = scoring_settings["scoring_impl"]
|
||||
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
assert isinstance(scoring_functions, list)
|
||||
assert len(scoring_functions) > 0
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
assert "meta-reference::equality" in function_ids
|
||||
assert "meta-reference::subset_of" in function_ids
|
||||
assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
|
||||
# get current provider_type we're testing
|
||||
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
|
||||
provider_type = provider.__provider_spec__.provider_type
|
||||
|
||||
for x in provider_scoring_functions[provider_type]:
|
||||
assert x in function_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -60,6 +79,17 @@ async def test_scoring_functions_register(scoring_settings):
|
|||
scoring_impl = scoring_settings["scoring_impl"]
|
||||
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
|
||||
datasets_impl = scoring_settings["datasets_impl"]
|
||||
|
||||
# get current provider_type we're testing
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
|
||||
provider_type = provider.__provider_spec__.provider_type
|
||||
if provider_type not in ("meta-reference"):
|
||||
pytest.skip(
|
||||
"Other scoring providers don't support registering scoring functions."
|
||||
)
|
||||
|
||||
test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
|
||||
# register the scoring function
|
||||
await scoring_functions_impl.register_scoring_function(
|
||||
|
@ -97,24 +127,26 @@ async def test_scoring_functions_register(scoring_settings):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_score(scoring_settings):
|
||||
async def test_scoring_score(scoring_settings, provider_scoring_functions):
|
||||
scoring_impl = scoring_settings["scoring_impl"]
|
||||
datasets_impl = scoring_settings["datasets_impl"]
|
||||
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
|
||||
await register_dataset(datasets_impl)
|
||||
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
||||
# get current provider_type we're testing
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
|
||||
provider_type = provider.__provider_spec__.provider_type
|
||||
|
||||
response = await scoring_impl.score_batch(
|
||||
dataset_id=response[0].identifier,
|
||||
scoring_functions=[
|
||||
"meta-reference::equality",
|
||||
"meta-reference::subset_of",
|
||||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
],
|
||||
scoring_functions=list(provider_scoring_functions[provider_type]),
|
||||
)
|
||||
|
||||
assert len(response.results) == 3
|
||||
assert "meta-reference::equality" in response.results
|
||||
assert "meta-reference::subset_of" in response.results
|
||||
assert "meta-reference::llm_as_judge_8b_correctness" in response.results
|
||||
assert len(response.results) == len(provider_scoring_functions[provider_type])
|
||||
for x in provider_scoring_functions[provider_type]:
|
||||
assert x in response.results
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue