mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 00:05:18 +00:00
scoring fn braintrust fixture
This commit is contained in:
parent
ca2cd71182
commit
258e01ec67
3 changed files with 31 additions and 2 deletions
|
@ -31,6 +31,15 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
id="meta_reference_scoring_together_inference",
|
id="meta_reference_scoring_together_inference",
|
||||||
marks=pytest.mark.meta_reference_scoring_together_inference,
|
marks=pytest.mark.meta_reference_scoring_together_inference,
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"scoring": "braintrust",
|
||||||
|
"datasetio": "localfs",
|
||||||
|
"inference": "together",
|
||||||
|
},
|
||||||
|
id="braintrust_scoring_together_inference",
|
||||||
|
marks=pytest.mark.braintrust_scoring_together_inference,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,6 +47,7 @@ def pytest_configure(config):
|
||||||
for fixture_name in [
|
for fixture_name in [
|
||||||
"meta_reference_scoring_fireworks_inference",
|
"meta_reference_scoring_fireworks_inference",
|
||||||
"meta_reference_scoring_together_inference",
|
"meta_reference_scoring_together_inference",
|
||||||
|
"braintrust_scoring_together_inference",
|
||||||
]:
|
]:
|
||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers",
|
"markers",
|
||||||
|
|
|
@ -31,7 +31,20 @@ def scoring_meta_reference() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SCORING_FIXTURES = ["meta_reference", "remote"]
|
@pytest.fixture(scope="session")
|
||||||
|
def scoring_braintrust() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
providers=[
|
||||||
|
Provider(
|
||||||
|
provider_id="braintrust",
|
||||||
|
provider_type="braintrust",
|
||||||
|
config={},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
|
@ -60,8 +60,9 @@ class TestScoring:
|
||||||
)
|
)
|
||||||
assert len(rows.rows) == 3
|
assert len(rows.rows) == 3
|
||||||
|
|
||||||
|
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
|
||||||
scoring_functions = {
|
scoring_functions = {
|
||||||
"meta-reference::equality": None,
|
scoring_fns_list[0].identifier: None,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await scoring_impl.score(
|
response = await scoring_impl.score(
|
||||||
|
@ -108,6 +109,11 @@ class TestScoring:
|
||||||
provider_id="",
|
provider_id="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
|
||||||
|
provider_id = scoring_fns_list[0].provider_id
|
||||||
|
if provider_id == "braintrust":
|
||||||
|
pytest.skip("Braintrust provider does not support scoring with params")
|
||||||
|
|
||||||
# scoring individual rows
|
# scoring individual rows
|
||||||
rows = await datasetio_impl.get_rows_paginated(
|
rows = await datasetio_impl.get_rows_paginated(
|
||||||
dataset_id="test_dataset",
|
dataset_id="test_dataset",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue