scoring fn braintrust fixture

This commit is contained in:
Xi Yan 2024-11-11 16:42:16 -05:00
parent ca2cd71182
commit 258e01ec67
3 changed files with 31 additions and 2 deletions

View file

@ -31,6 +31,15 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="meta_reference_scoring_together_inference",
marks=pytest.mark.meta_reference_scoring_together_inference,
),
pytest.param(
{
"scoring": "braintrust",
"datasetio": "localfs",
"inference": "together",
},
id="braintrust_scoring_together_inference",
marks=pytest.mark.braintrust_scoring_together_inference,
),
]
@ -38,6 +47,7 @@ def pytest_configure(config):
for fixture_name in [
"meta_reference_scoring_fireworks_inference",
"meta_reference_scoring_together_inference",
"braintrust_scoring_together_inference",
]:
config.addinivalue_line(
"markers",

View file

@ -31,7 +31,20 @@ def scoring_meta_reference() -> ProviderFixture:
)
SCORING_FIXTURES = ["meta_reference", "remote"]
@pytest.fixture(scope="session")
def scoring_braintrust() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="braintrust",
provider_type="braintrust",
config={},
)
],
)
SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
@pytest_asyncio.fixture(scope="session")

View file

@ -60,8 +60,9 @@ class TestScoring:
)
assert len(rows.rows) == 3
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
scoring_functions = {
"meta-reference::equality": None,
scoring_fns_list[0].identifier: None,
}
response = await scoring_impl.score(
@ -108,6 +109,11 @@ class TestScoring:
provider_id="",
)
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
if provider_id == "braintrust":
pytest.skip("Braintrust provider does not support scoring with params")
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",