scoring fn braintrust fixture

This commit is contained in:
Xi Yan 2024-11-11 16:42:16 -05:00
parent ca2cd71182
commit 258e01ec67
3 changed files with 31 additions and 2 deletions

View file

@ -31,6 +31,15 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="meta_reference_scoring_together_inference", id="meta_reference_scoring_together_inference",
marks=pytest.mark.meta_reference_scoring_together_inference, marks=pytest.mark.meta_reference_scoring_together_inference,
), ),
pytest.param(
{
"scoring": "braintrust",
"datasetio": "localfs",
"inference": "together",
},
id="braintrust_scoring_together_inference",
marks=pytest.mark.braintrust_scoring_together_inference,
),
] ]
@ -38,6 +47,7 @@ def pytest_configure(config):
for fixture_name in [ for fixture_name in [
"meta_reference_scoring_fireworks_inference", "meta_reference_scoring_fireworks_inference",
"meta_reference_scoring_together_inference", "meta_reference_scoring_together_inference",
"braintrust_scoring_together_inference",
]: ]:
config.addinivalue_line( config.addinivalue_line(
"markers", "markers",

View file

@ -31,7 +31,20 @@ def scoring_meta_reference() -> ProviderFixture:
) )
SCORING_FIXTURES = ["meta_reference", "remote"] @pytest.fixture(scope="session")
def scoring_braintrust() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="braintrust",
provider_type="braintrust",
config={},
)
],
)
SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")

View file

@ -60,8 +60,9 @@ class TestScoring:
) )
assert len(rows.rows) == 3 assert len(rows.rows) == 3
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
scoring_functions = { scoring_functions = {
"meta-reference::equality": None, scoring_fns_list[0].identifier: None,
} }
response = await scoring_impl.score( response = await scoring_impl.score(
@ -108,6 +109,11 @@ class TestScoring:
provider_id="", provider_id="",
) )
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
if provider_id == "braintrust":
pytest.skip("Braintrust provider does not support scoring with params")
# scoring individual rows # scoring individual rows
rows = await datasetio_impl.get_rows_paginated( rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset", dataset_id="test_dataset",