[Evals API][7/n] braintrust scoring provider (#333)

* wip scoring refactor

* llm as judge, move folders

* test full generation + eval

* extract score regex to llm context

* remove prints, cleanup braintrust in this branch

* braintrust skeleton

* datasetio test fix

* braintrust provider

* remove prints

* dependencies

* change json -> class

* json -> class

* remove initialize

* address nits

* check identifier prefix

* braintrust scoring identifier check, rebase

* update MANIFEST

* manifest

* remove braintrust scoring_fn

* remove comments

* tests

* imports fix
Xi Yan 2024-10-28 18:59:35 -07:00 committed by GitHub
parent ae671eaf7a
commit ed833bb758
11 changed files with 274 additions and 15 deletions


@@ -7,6 +7,9 @@ providers:
    - provider_id: test-meta
      provider_type: meta-reference
      config: {}
    - provider_id: test-braintrust
      provider_type: braintrust
      config: {}
  inference:
    - provider_id: tgi0
      provider_type: remote::tgi
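
For orientation (not part of the diff): scoring functions are namespaced by the provider that serves them, using a "<provider_type>::<name>" identifier prefix, which is what the identifier-prefix checks in the commit list above refer to. A minimal sketch of that convention, using only identifiers that appear in the tests below; the provider_of helper is hypothetical:

# Illustrative sketch only -- not code from this commit.
# Scoring function identifiers are namespaced as "<provider_type>::<function>",
# e.g. "braintrust::factuality" or "meta-reference::equality".
def provider_of(identifier: str) -> str:
    # Hypothetical helper: the prefix before "::" names the serving provider.
    return identifier.split("::", 1)[0]

assert provider_of("braintrust::factuality") == "braintrust"
assert provider_of("meta-reference::equality") == "meta-reference"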


@@ -43,16 +43,35 @@ async def scoring_settings():
    }
@pytest_asyncio.fixture(scope="session")
async def provider_scoring_functions():
    return {
        "meta-reference": {
            "meta-reference::equality",
            "meta-reference::subset_of",
            "meta-reference::llm_as_judge_8b_correctness",
        },
        "braintrust": {
            "braintrust::factuality",
            "braintrust::answer-correctness",
        },
    }
@pytest.mark.asyncio
async def test_scoring_functions_list(scoring_settings):
async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
    scoring_impl = scoring_settings["scoring_impl"]
    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
    scoring_functions = await scoring_functions_impl.list_scoring_functions()
    assert isinstance(scoring_functions, list)
    assert len(scoring_functions) > 0
    function_ids = [f.identifier for f in scoring_functions]
    assert "meta-reference::equality" in function_ids
    assert "meta-reference::subset_of" in function_ids
    assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
    # get current provider_type we're testing
    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
    provider_type = provider.__provider_spec__.provider_type
    for x in provider_scoring_functions[provider_type]:
        assert x in function_ids
@pytest.mark.asyncio
@@ -60,6 +79,17 @@ async def test_scoring_functions_register(scoring_settings):
    scoring_impl = scoring_settings["scoring_impl"]
    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
    datasets_impl = scoring_settings["datasets_impl"]
    # get current provider_type we're testing
    scoring_functions = await scoring_functions_impl.list_scoring_functions()
    function_ids = [f.identifier for f in scoring_functions]
    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
    provider_type = provider.__provider_spec__.provider_type
    if provider_type not in ("meta-reference",):
        pytest.skip(
            "Other scoring providers don't support registering scoring functions."
        )
    test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
    # register the scoring function
    await scoring_functions_impl.register_scoring_function(
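
The judge prompt above pins the answer format to "Number: <answer>" so a numeric score can be extracted from the judge's completion (see "extract score regex to llm context" in the commit list). A minimal parsing sketch under that assumption; the regex and helper name are illustrative, not taken from this diff:

# Illustrative sketch only -- the real extraction lives in the provider code,
# not in this diff; regex and helper name are assumptions.
import re

def parse_judge_score(judge_response):
    # Pull the integer after "Number:" as required by the prompt format above.
    match = re.search(r"Number:\s*(\d+)", judge_response)
    return int(match.group(1)) if match else None

assert parse_judge_score("Number: 7") == 7
assert parse_judge_score("no score here") is None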
@@ -97,24 +127,26 @@ async def test_scoring_functions_register(scoring_settings):
@pytest.mark.asyncio
async def test_scoring_score(scoring_settings):
async def test_scoring_score(scoring_settings, provider_scoring_functions):
    scoring_impl = scoring_settings["scoring_impl"]
    datasets_impl = scoring_settings["datasets_impl"]
    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
    await register_dataset(datasets_impl)
    response = await datasets_impl.list_datasets()
    assert len(response) == 1
    # get current provider_type we're testing
    scoring_functions = await scoring_functions_impl.list_scoring_functions()
    function_ids = [f.identifier for f in scoring_functions]
    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
    provider_type = provider.__provider_spec__.provider_type
    response = await scoring_impl.score_batch(
        dataset_id=response[0].identifier,
        scoring_functions=[
            "meta-reference::equality",
            "meta-reference::subset_of",
            "meta-reference::llm_as_judge_8b_correctness",
        ],
        scoring_functions=list(provider_scoring_functions[provider_type]),
    )
    assert len(response.results) == 3
    assert "meta-reference::equality" in response.results
    assert "meta-reference::subset_of" in response.results
    assert "meta-reference::llm_as_judge_8b_correctness" in response.results
    assert len(response.results) == len(provider_scoring_functions[provider_type])
    for x in provider_scoring_functions[provider_type]:
        assert x in response.results
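
Taken together, these tests drive the new braintrust provider through the same generic scoring surface as meta-reference. A minimal end-to-end sketch of that flow, assuming the same scoring_settings fixture and register_dataset helper used above; illustrative only, not code from this commit:

# Illustrative sketch only -- mirrors the test flow above with the braintrust
# scoring functions spelled out explicitly; the fixtures and register_dataset
# helper are assumed to be the same ones used in the tests.
import pytest

@pytest.mark.asyncio
async def test_braintrust_score_sketch(scoring_settings):
    scoring_impl = scoring_settings["scoring_impl"]
    datasets_impl = scoring_settings["datasets_impl"]

    await register_dataset(datasets_impl)
    datasets = await datasets_impl.list_datasets()

    response = await scoring_impl.score_batch(
        dataset_id=datasets[0].identifier,
        scoring_functions=[
            "braintrust::factuality",
            "braintrust::answer-correctness",
        ],
    )
    # Each requested scoring function is keyed by its identifier in the results.
    for fn_id in ("braintrust::factuality", "braintrust::answer-correctness"):
        assert fn_id in response.results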