mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
scoring fn braintrust fixture
This commit is contained in:
parent
ca2cd71182
commit
258e01ec67
3 changed files with 31 additions and 2 deletions
|
@ -31,6 +31,15 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
id="meta_reference_scoring_together_inference",
|
||||
marks=pytest.mark.meta_reference_scoring_together_inference,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"scoring": "braintrust",
|
||||
"datasetio": "localfs",
|
||||
"inference": "together",
|
||||
},
|
||||
id="braintrust_scoring_together_inference",
|
||||
marks=pytest.mark.braintrust_scoring_together_inference,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
@ -38,6 +47,7 @@ def pytest_configure(config):
|
|||
for fixture_name in [
|
||||
"meta_reference_scoring_fireworks_inference",
|
||||
"meta_reference_scoring_together_inference",
|
||||
"braintrust_scoring_together_inference",
|
||||
]:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
|
|
|
@ -31,7 +31,20 @@ def scoring_meta_reference() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
SCORING_FIXTURES = ["meta_reference", "remote"]
|
||||
@pytest.fixture(scope="session")
|
||||
def scoring_braintrust() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="braintrust",
|
||||
provider_type="braintrust",
|
||||
config={},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -60,8 +60,9 @@ class TestScoring:
|
|||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
|
||||
scoring_functions = {
|
||||
"meta-reference::equality": None,
|
||||
scoring_fns_list[0].identifier: None,
|
||||
}
|
||||
|
||||
response = await scoring_impl.score(
|
||||
|
@ -108,6 +109,11 @@ class TestScoring:
|
|||
provider_id="",
|
||||
)
|
||||
|
||||
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
|
||||
provider_id = scoring_fns_list[0].provider_id
|
||||
if provider_id == "braintrust":
|
||||
pytest.skip("Braintrust provider does not support scoring with params")
|
||||
|
||||
# scoring individual rows
|
||||
rows = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue