[Evals API][10/n] API updates for EvalTaskDef + new test migration (#379)

* wip

* scoring fn api

* eval api

* eval task

* evaluate api update

* pre commit

* unwrap context -> config

* config field doc

* typo

* naming fix

* separate benchmark / app eval

* api name

* rename

* wip tests

* wip

* datasetio test

* delete unused

* fixture

* scoring resolve

* fix scoring register

* scoring test pass

* score batch

* scoring fix

* fix eval

* test eval works

* remove type ignore

* api refactor

* add default task_eval_id for routing

* add eval_id for jobs

* remove type ignore

* only keep 1 run_eval

* fix optional

* register task required

* register task required

* delete old tests

* delete old tests

* fixture return impl
Authored by Xi Yan on 2024-11-07 21:24:12 -08:00, committed by GitHub
parent 8350f2df4c
commit 6192bf43a4
32 changed files with 916 additions and 389 deletions
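
The migrated tests in this commit drive the updated scoring API, where score() and score_batch() accept a mapping of scoring-function identifiers to optional per-call parameters. The sketch below only illustrates that call shape; it assumes impls resolved by the new scoring_stack fixture shown in the diffs, and the helper name score_three_rows is not part of the change.

from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def score_three_rows(scoring_impl, datasetio_impl) -> None:
    # Illustrative flow mirroring the migrated tests: fetch a few rows from a
    # registered dataset, then score them with a mix of default and
    # parameterized scoring functions.
    rows = await datasetio_impl.get_rows_paginated(
        dataset_id="test_dataset",
        rows_in_page=3,
    )
    scoring_functions = {
        # None -> use the scoring function's default parameters
        "meta-reference::equality": None,
        # or override the judge parameters per call
        "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
            judge_model="Llama3.1-405B-Instruct",
            prompt_template="Output a number response in the following format: Score: <number>",
            judge_score_regexes=[r"Score: (\d+)"],
        ),
    }
    response = await scoring_impl.score(
        input_rows=rows.rows,
        scoring_functions=scoring_functions,
    )
    # results are keyed by scoring-function identifier
    assert set(response.results) == set(scoring_functions)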

llama_stack/providers/tests/scoring/conftest.py (new file)

@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import SCORING_FIXTURES

DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "scoring": "meta_reference",
            "datasetio": "meta_reference",
            "inference": "fireworks",
        },
        id="meta_reference_scoring_fireworks_inference",
        marks=pytest.mark.meta_reference_scoring_fireworks_inference,
    ),
    pytest.param(
        {
            "scoring": "meta_reference",
            "datasetio": "meta_reference",
            "inference": "together",
        },
        id="meta_reference_scoring_together_inference",
        marks=pytest.mark.meta_reference_scoring_together_inference,
    ),
]


def pytest_configure(config):
    for fixture_name in [
        "meta_reference_scoring_fireworks_inference",
        "meta_reference_scoring_together_inference",
    ]:
        config.addinivalue_line(
            "markers",
            f"{fixture_name}: marks tests as {fixture_name} specific",
        )


def pytest_addoption(parser):
    parser.addoption(
        "--inference-model",
        action="store",
        default="Llama3.2-3B-Instruct",
        help="Specify the inference model to use for testing",
    )


def pytest_generate_tests(metafunc):
    if "scoring_stack" in metafunc.fixturenames:
        available_fixtures = {
            "scoring": SCORING_FIXTURES,
            "datasetio": DATASETIO_FIXTURES,
            "inference": INFERENCE_FIXTURES,
        }
        combinations = (
            get_provider_fixture_overrides(metafunc.config, available_fixtures)
            or DEFAULT_PROVIDER_COMBINATIONS
        )
        metafunc.parametrize("scoring_stack", combinations, indirect=True)
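
For context, this conftest expands any test that requests the scoring_stack fixture into one case per provider combination, and the markers registered above (e.g. -m meta_reference_scoring_together_inference) narrow the run to a single combination. A minimal consumer, assuming the same layout as test_scoring.py below; the test class and assertions here are illustrative, not part of the commit.

import pytest


class TestScoringStackSmoke:
    @pytest.mark.asyncio
    async def test_impls_resolve(self, scoring_stack):
        # scoring_stack is resolved once per provider combination generated by
        # pytest_generate_tests above and returns four impls (see fixtures.py).
        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
            scoring_stack
        )
        assert scoring_impl is not None
        assert scoring_functions_impl is not None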

llama_stack/providers/tests/scoring/fixtures.py (new file)

@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
import pytest_asyncio

from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2

from ..conftest import ProviderFixture, remote_stack_fixture


@pytest.fixture(scope="session")
def scoring_remote() -> ProviderFixture:
    return remote_stack_fixture()


@pytest.fixture(scope="session")
def scoring_meta_reference() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="meta-reference",
                provider_type="meta-reference",
                config={},
            )
        ],
    )


SCORING_FIXTURES = ["meta_reference", "remote"]


@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request):
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["datasetio", "scoring", "inference"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    impls = await resolve_impls_for_test_v2(
        [Api.scoring, Api.datasetio, Api.inference],
        providers,
        provider_data,
    )

    return (
        impls[Api.scoring],
        impls[Api.scoring_functions],
        impls[Api.datasetio],
        impls[Api.datasets],
    )
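
A note on extensibility: adding another scoring backend to the matrix only needs another ProviderFixture plus an entry in SCORING_FIXTURES. The fixture below is a hypothetical sketch extending fixtures.py above (imports as in that file); the braintrust provider_id/provider_type are taken from the deleted provider_config_example.yaml, not from this commit.

@pytest.fixture(scope="session")
def scoring_braintrust() -> ProviderFixture:
    # Hypothetical: wires the braintrust scoring provider into the test matrix.
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="braintrust",
                provider_type="braintrust",
                config={},
            )
        ],
    )


# SCORING_FIXTURES would then read ["meta_reference", "remote", "braintrust"].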

llama_stack/providers/tests/scoring/provider_config_example.yaml (deleted)

@@ -1,17 +0,0 @@
providers:
  datasetio:
    - provider_id: test-meta
      provider_type: meta-reference
      config: {}
  scoring:
    - provider_id: test-meta
      provider_type: meta-reference
      config: {}
    - provider_id: test-braintrust
      provider_type: braintrust
      config: {}
  inference:
    - provider_id: tgi0
      provider_type: remote::tgi
      config:
        url: http://127.0.0.1:5009

llama_stack/providers/tests/scoring/test_scoring.py

@@ -3,150 +3,109 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+import pytest
+
+from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
-from llama_stack.providers.tests.resolver import resolve_impls_for_test

 # How to run this test:
 #
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#   PROVIDER_CONFIG=provider_config.yaml \
-#   pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
-#   --tb=short --disable-warnings
-# ```
-
-
-@pytest_asyncio.fixture(scope="session")
-async def scoring_settings():
-    impls = await resolve_impls_for_test(
-        Api.scoring, deps=[Api.datasetio, Api.inference]
-    )
-    return {
-        "scoring_impl": impls[Api.scoring],
-        "scoring_functions_impl": impls[Api.scoring_functions],
-        "datasets_impl": impls[Api.datasets],
-    }
-
-
-@pytest_asyncio.fixture(scope="session")
-async def provider_scoring_functions():
-    return {
-        "meta-reference": {
-            "meta-reference::equality",
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        },
-        "braintrust": {
-            "braintrust::factuality",
-            "braintrust::answer-correctness",
-        },
-    }
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    # get current provider_type we're testing
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-    for x in provider_scoring_functions[provider_type]:
-        assert x in function_ids
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_register(scoring_settings):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-    if provider_type not in ("meta-reference"):
-        pytest.skip(
-            "Other scoring providers don't support registering scoring functions."
-        )
-
-    test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
-    # register the scoring function
-    await scoring_functions_impl.register_scoring_function(
-        ScoringFnDefWithProvider(
-            identifier="meta-reference::llm_as_judge_8b_random",
-            description="Llm As Judge Scoring Function",
-            parameters=[],
-            return_type=NumberType(),
-            context=LLMAsJudgeContext(
-                prompt_template=test_prompt,
-                judge_model="Llama3.1-8B-Instruct",
-                judge_score_regex=[r"Number: (\d+)"],
-            ),
-            provider_id="test-meta",
-        )
-    )
-
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::llm_as_judge_8b_random" in function_ids
-
-    # test score using newly registered scoring function
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::llm_as_judge_8b_random",
-        ],
-    )
-    assert "meta-reference::llm_as_judge_8b_random" in response.results
-
-
-@pytest.mark.asyncio
-async def test_scoring_score(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    await register_dataset(datasets_impl)
-
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=list(provider_scoring_functions[provider_type]),
-    )
-    assert len(response.results) == len(provider_scoring_functions[provider_type])
-    for x in provider_scoring_functions[provider_type]:
-        assert x in response.results
+# pytest llama_stack/providers/tests/scoring/test_scoring.py
+#   -m "meta_reference"
+#   -v -s --tb=short --disable-warnings
+
+
+class TestScoring:
+    @pytest.mark.asyncio
+    async def test_scoring_functions_list(self, scoring_stack):
+        # NOTE: this needs you to ensure that you are starting from a clean state
+        # but so far we don't have an unregister API unfortunately, so be careful
+        _, scoring_functions_impl, _, _ = scoring_stack
+        response = await scoring_functions_impl.list_scoring_functions()
+        assert isinstance(response, list)
+        assert len(response) > 0
+
+    @pytest.mark.asyncio
+    async def test_scoring_score(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5
+
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_params(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+                judge_model="Llama3.1-405B-Instruct",
+                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
+                judge_score_regexes=[r"Score: (\d+)"],
+            )
+        }
+
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5
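
One detail worth calling out in test_scoring_score_with_params: the new LLMAsJudgeScoringFnParams carries judge_score_regexes, which the provider presumably applies to the judge model's text output to extract a numeric score. A standalone illustration of that parsing step; the parse_judge_score helper is illustrative, not the provider's actual implementation.

import re
from typing import Optional

JUDGE_SCORE_REGEXES = [r"Score: (\d+)"]


def parse_judge_score(judge_response: str) -> Optional[int]:
    # Try each regex in order and return the first captured score, mirroring
    # how judge_score_regexes in the test params is presumably consumed.
    for pattern in JUDGE_SCORE_REGEXES:
        match = re.search(pattern, judge_response)
        if match:
            return int(match.group(1))
    return None


assert parse_judge_score("Score: 7, the answer is mostly correct.") == 7
assert parse_judge_score("no score emitted") is None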