fix tests after registration migration & rename meta-reference -> basic / llm_as_judge provider (#424)

* rename meta-reference -> basic
* config rename
* impl rename
* rename llm_as_judge, fix test
* util
* rebase
* naming fix
parent 3d7561e55c
commit 84c6fbbd93

24 changed files with 268 additions and 73 deletions
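In short, this commit splits the old meta-reference scoring provider into two inline providers (basic and llm-as-judge) and namespaces provider_type strings with an inline:: prefix. A minimal sketch of a renamed provider entry, assuming the Provider model that the test fixtures below use (the llama_stack.distribution.datatypes import path is an assumption about this version of the codebase):

    # Sketch only: after the rename, scoring providers are addressed by
    # namespaced "inline::" provider types; "meta-reference" no longer exists.
    from llama_stack.distribution.datatypes import Provider  # assumed import path

    llm_as_judge_provider = Provider(
        provider_id="llm-as-judge",
        provider_type="inline::llm-as-judge",  # previously folded into "meta-reference"
        config={},
    )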
@@ -15,21 +15,12 @@ from .fixtures import SCORING_FIXTURES

 DEFAULT_PROVIDER_COMBINATIONS = [
-    pytest.param(
-        {
-            "scoring": "meta_reference",
-            "datasetio": "localfs",
-            "inference": "fireworks",
-        },
-        id="meta_reference_scoring_fireworks_inference",
-        marks=pytest.mark.meta_reference_scoring_fireworks_inference,
-    ),
     pytest.param(
         {
-            "scoring": "meta_reference",
+            "scoring": "basic",
             "datasetio": "localfs",
             "inference": "together",
         },
-        id="meta_reference_scoring_together_inference",
-        marks=pytest.mark.meta_reference_scoring_together_inference,
+        id="basic_scoring_together_inference",
+        marks=pytest.mark.basic_scoring_together_inference,
     ),
     pytest.param(
         {
@@ -40,13 +31,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="braintrust_scoring_together_inference",
         marks=pytest.mark.braintrust_scoring_together_inference,
     ),
+    pytest.param(
+        {
+            "scoring": "llm_as_judge",
+            "datasetio": "localfs",
+            "inference": "together",
+        },
+        id="llm_as_judge_scoring_together_inference",
+        marks=pytest.mark.llm_as_judge_scoring_together_inference,
+    ),
 ]


 def pytest_configure(config):
     for fixture_name in [
-        "meta_reference_scoring_fireworks_inference",
-        "meta_reference_scoring_together_inference",
+        "basic_scoring_together_inference",
         "braintrust_scoring_together_inference",
     ]:
         config.addinivalue_line(
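For context on how these marks take effect: a minimal sketch of the pytest_generate_tests hook that conftest files in this test suite typically define (the hook itself is not part of this diff; the fixture name "scoring_stack" is taken from the test hunks below):

    def pytest_generate_tests(metafunc):
        # Parametrize the provider stack with the combinations defined above,
        # so `pytest -m basic_scoring_together_inference` selects only the
        # renamed basic-scoring combination.
        if "scoring_stack" in metafunc.fixturenames:
            metafunc.parametrize(
                "scoring_stack", DEFAULT_PROVIDER_COMBINATIONS, indirect=True
            )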
@@ -19,12 +19,12 @@ def scoring_remote() -> ProviderFixture:


 @pytest.fixture(scope="session")
-def scoring_meta_reference() -> ProviderFixture:
+def scoring_basic() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="meta-reference",
-                provider_type="meta-reference",
+                provider_id="basic",
+                provider_type="inline::basic",
                 config={},
             )
         ],
@@ -37,14 +37,27 @@ def scoring_braintrust() -> ProviderFixture:
         providers=[
             Provider(
                 provider_id="braintrust",
-                provider_type="braintrust",
+                provider_type="inline::braintrust",
                 config={},
             )
         ],
     )


-SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
+@pytest.fixture(scope="session")
+def scoring_llm_as_judge() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="llm-as-judge",
+                provider_type="inline::llm-as-judge",
+                config={},
+            )
+        ],
+    )
+
+
+SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]


 @pytest_asyncio.fixture(scope="session")
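The names in SCORING_FIXTURES are resolved to the fixtures above by convention. A minimal sketch of that indirection, assuming the f"{api}_{name}" naming implied by scoring_basic and scoring_llm_as_judge (the real session fixture also assembles the stack implementations; "scoring_providers" is a hypothetical name):

    import pytest

    @pytest.fixture(scope="session")
    def scoring_providers(request):
        # request.param is the dict from a pytest.param above, e.g.
        # {"scoring": "llm_as_judge", "datasetio": "localfs", "inference": "together"};
        # each value resolves to a ProviderFixture such as scoring_llm_as_judge.
        return {
            api: request.getfixturevalue(f"{api}_{name}")
            for api, name in request.param.items()
        }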
@@ -43,6 +43,13 @@ class TestScoring:
             scoring_stack[Api.datasets],
             scoring_stack[Api.models],
         )
+        scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
+        provider_id = scoring_fns_list[0].provider_id
+        if provider_id == "llm-as-judge":
+            pytest.skip(
+                f"{provider_id} provider does not support scoring without params"
+            )
+
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
@@ -111,8 +118,8 @@ class TestScoring:

         scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
         provider_id = scoring_fns_list[0].provider_id
-        if provider_id == "braintrust":
-            pytest.skip("Braintrust provider does not support scoring with params")
+        if provider_id == "braintrust" or provider_id == "basic":
+            pytest.skip(f"{provider_id} provider does not support scoring with params")

         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
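Taken together with the earlier hunk, the two skip branches encode which scoring providers accept per-call params; a summary sketch inferred from those skips (not a table from the source):

    # Whether each provider supports passing ScoringFnParams at score time,
    # as implied by the pytest.skip branches in these tests.
    SUPPORTS_SCORING_PARAMS = {
        "basic": False,         # fixed heuristic scorers, no params
        "braintrust": False,    # scoring behavior is managed by Braintrust
        "llm-as-judge": True,   # requires judge model, prompt template, regexes
    }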
@@ -122,7 +129,7 @@ class TestScoring:
         assert len(rows.rows) == 3

         scoring_functions = {
-            "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+            "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
                 judge_score_regexes=[r"Score: (\d+)"],
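For completeness, a minimal sketch of how the renamed scoring function identifier is then exercised, assuming the score() call shape used elsewhere in this test file (field names follow llama-stack's scoring API and may vary by version):

    response = await scoring_impl.score(
        input_rows=rows.rows,
        scoring_functions=scoring_functions,  # keyed by "llm-as-judge::llm_as_judge_base"
    )
    # Expect one ScoringResult per requested scoring function id.
    assert "llm-as-judge::llm_as_judge_base" in response.results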