From 258e01ec6775979ebf7305d69af6d5c4168d8f2c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 11 Nov 2024 16:42:16 -0500 Subject: [PATCH] scoring fn braintrust fixture --- llama_stack/providers/tests/scoring/conftest.py | 10 ++++++++++ llama_stack/providers/tests/scoring/fixtures.py | 15 ++++++++++++++- .../providers/tests/scoring/test_scoring.py | 8 +++++++- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py index 513180ef4..ed56df230 100644 --- a/llama_stack/providers/tests/scoring/conftest.py +++ b/llama_stack/providers/tests/scoring/conftest.py @@ -31,6 +31,15 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="meta_reference_scoring_together_inference", marks=pytest.mark.meta_reference_scoring_together_inference, ), + pytest.param( + { + "scoring": "braintrust", + "datasetio": "localfs", + "inference": "together", + }, + id="braintrust_scoring_together_inference", + marks=pytest.mark.braintrust_scoring_together_inference, + ), ] @@ -38,6 +47,7 @@ def pytest_configure(config): for fixture_name in [ "meta_reference_scoring_fireworks_inference", "meta_reference_scoring_together_inference", + "braintrust_scoring_together_inference", ]: config.addinivalue_line( "markers", diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py index 96409d200..648d35859 100644 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -31,7 +31,20 @@ def scoring_meta_reference() -> ProviderFixture: ) -SCORING_FIXTURES = ["meta_reference", "remote"] +@pytest.fixture(scope="session") +def scoring_braintrust() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="braintrust", + provider_type="braintrust", + config={}, + ) + ], + ) + + +SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index 1f2608f3b..f3c925048 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -60,8 +60,9 @@ class TestScoring: ) assert len(rows.rows) == 3 + scoring_fns_list = await scoring_functions_impl.list_scoring_functions() scoring_functions = { - "meta-reference::equality": None, + scoring_fns_list[0].identifier: None, } response = await scoring_impl.score( @@ -108,6 +109,11 @@ class TestScoring: provider_id="", ) + scoring_fns_list = await scoring_functions_impl.list_scoring_functions() + provider_id = scoring_fns_list[0].provider_id + if provider_id == "braintrust": + pytest.skip("Braintrust provider does not support scoring with params") + # scoring individual rows rows = await datasetio_impl.get_rows_paginated( dataset_id="test_dataset",