fix tests after registration migration & rename meta-reference -> basic / llm_as_judge provider (#424)

* rename meta-reference -> basic

* config rename

* impl rename

* rename llm_as_judge, fix test

* util

* rebase

* naming fix
Xi Yan 2024-11-12 10:35:44 -05:00 committed by GitHub
parent 3d7561e55c
commit 84c6fbbd93
24 changed files with 268 additions and 73 deletions
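
In practical terms, the commit splits the single catch-all meta-reference scoring provider into two inline providers. A minimal before/after sketch (the Provider import path is assumed from the test fixtures in this commit; these objects are illustrative, not taken from the diff):

# Sketch only; import path assumed from the fixtures file below.
from llama_stack.distribution.datatypes import Provider

# Before this commit: one provider covered all scoring functions.
old = Provider(provider_id="meta-reference", provider_type="meta-reference", config={})

# After: rule-based and judge-based scoring are registered separately,
# under the inline:: provider-type namespace.
basic = Provider(provider_id="basic", provider_type="inline::basic", config={})
judge = Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge", config={})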

llama_stack/providers/tests/scoring/conftest.py

@@ -15,21 +15,12 @@ from .fixtures import SCORING_FIXTURES
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
-            "scoring": "meta_reference",
-            "datasetio": "localfs",
-            "inference": "fireworks",
-        },
-        id="meta_reference_scoring_fireworks_inference",
-        marks=pytest.mark.meta_reference_scoring_fireworks_inference,
-    ),
-    pytest.param(
-        {
-            "scoring": "meta_reference",
+            "scoring": "basic",
             "datasetio": "localfs",
             "inference": "together",
         },
-        id="meta_reference_scoring_together_inference",
-        marks=pytest.mark.meta_reference_scoring_together_inference,
+        id="basic_scoring_together_inference",
+        marks=pytest.mark.basic_scoring_together_inference,
     ),
     pytest.param(
         {
@@ -40,13 +31,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="braintrust_scoring_together_inference",
         marks=pytest.mark.braintrust_scoring_together_inference,
     ),
+    pytest.param(
+        {
+            "scoring": "llm_as_judge",
+            "datasetio": "localfs",
+            "inference": "together",
+        },
+        id="llm_as_judge_scoring_together_inference",
+        marks=pytest.mark.llm_as_judge_scoring_together_inference,
+    ),
 ]
 
 
 def pytest_configure(config):
     for fixture_name in [
-        "meta_reference_scoring_fireworks_inference",
-        "meta_reference_scoring_together_inference",
+        "basic_scoring_together_inference",
         "braintrust_scoring_together_inference",
     ]:
         config.addinivalue_line(
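
The hunk above ends mid-call. For reference, a hedged sketch of the usual addinivalue_line registration pattern (the descriptive string is an assumption; note that the visible hunk does not add the new llm_as_judge marker to this list):

# Sketch of the pytest marker-registration pattern; the message string
# passed for each marker is an assumption, not taken from this commit.
def pytest_configure(config):
    for fixture_name in [
        "basic_scoring_together_inference",
        "braintrust_scoring_together_inference",
    ]:
        config.addinivalue_line(
            "markers",
            f"{fixture_name}: marks tests as {fixture_name} specific",
        )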

llama_stack/providers/tests/scoring/fixtures.py

@@ -19,12 +19,12 @@ def scoring_remote() -> ProviderFixture:
 
 
 @pytest.fixture(scope="session")
-def scoring_meta_reference() -> ProviderFixture:
+def scoring_basic() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="meta-reference",
-                provider_type="meta-reference",
+                provider_id="basic",
+                provider_type="inline::basic",
                 config={},
             )
         ],
@@ -37,14 +37,27 @@ def scoring_braintrust() -> ProviderFixture:
         providers=[
             Provider(
                 provider_id="braintrust",
-                provider_type="braintrust",
+                provider_type="inline::braintrust",
                 config={},
             )
         ],
     )
 
 
-SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
+@pytest.fixture(scope="session")
+def scoring_llm_as_judge() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="llm-as-judge",
+                provider_type="inline::llm-as-judge",
+                config={},
+            )
+        ],
+    )
+
+
+SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
 
 
 @pytest_asyncio.fixture(scope="session")
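
With llm_as_judge added to SCORING_FIXTURES, the new provider combination can be exercised through its pytest marker. A hypothetical invocation (the test file path is inferred from this commit, not stated in it):

# Hypothetical: run only the new llm-as-judge scoring combination.
import pytest

exit_code = pytest.main([
    "llama_stack/providers/tests/scoring/test_scoring.py",  # assumed path
    "-m", "llm_as_judge_scoring_together_inference",
    "-v",
])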

llama_stack/providers/tests/scoring/test_scoring.py

@@ -43,6 +43,13 @@ class TestScoring:
             scoring_stack[Api.datasets],
             scoring_stack[Api.models],
         )
+        scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
+        provider_id = scoring_fns_list[0].provider_id
+        if provider_id == "llm-as-judge":
+            pytest.skip(
+                f"{provider_id} provider does not support scoring without params"
+            )
+
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
@@ -111,8 +118,8 @@ class TestScoring:
 
         scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
         provider_id = scoring_fns_list[0].provider_id
-        if provider_id == "braintrust":
-            pytest.skip("Braintrust provider does not support scoring with params")
+        if provider_id == "braintrust" or provider_id == "basic":
+            pytest.skip(f"{provider_id} provider does not support scoring with params")
 
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
@@ -122,7 +129,7 @@ class TestScoring:
         assert len(rows.rows) == 3
 
         scoring_functions = {
-            "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+            "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
                 judge_score_regexes=[r"Score: (\d+)"],
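
After the rename, the judge scoring function is addressed as llm-as-judge::llm_as_judge_base rather than meta-reference::llm_as_judge_base. A sketch of how these params flow into a score call, assuming the score(input_rows=..., scoring_functions=...) shape that the surrounding test code suggests (import path and signature are assumptions):

# Sketch under assumptions: the import path and scoring_impl.score() signature
# mirror the calls visible in these tests; this is not a verified API reference.
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams

async def run_judge_scoring(scoring_impl, rows):
    scoring_functions = {
        "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
            judge_model="Llama3.1-405B-Instruct",
            prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
            judge_score_regexes=[r"Score: (\d+)"],
        )
    }
    response = await scoring_impl.score(
        input_rows=rows.rows,
        scoring_functions=scoring_functions,
    )
    # One result set per requested scoring function.
    assert len(response.results) == len(scoring_functions)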