This commit is contained in:
Xi Yan 2024-11-06 21:10:54 -08:00
parent 56239fce90
commit 413a1b6d8b
13 changed files with 293 additions and 15 deletions

View file

@ -151,4 +151,5 @@ pytest_plugins = [
"llama_stack.providers.tests.agents.fixtures",
"llama_stack.providers.tests.datasetio.fixtures",
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
]

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from .fixtures import EVAL_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"eval": "meta_reference",
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "fireworks",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
),
pytest.param(
{
"eval": "meta_reference",
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "together",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
),
]
def pytest_configure(config):
for fixture_name in [
"meta_reference_eval_fireworks_inference",
"meta_reference_eval_together_inference",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="Llama3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
def pytest_generate_tests(metafunc):
if "eval_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": EVAL_FIXTURES,
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("eval_stack", combinations, indirect=True)

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def eval_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def eval_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config={},
)
],
)
EVAL_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def eval_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["datasetio", "eval", "scoring", "inference"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
impls = await resolve_impls_for_test_v2(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
providers,
provider_data,
)
return (
impls[Api.eval],
impls[Api.eval_tasks],
impls[Api.scoring],
impls[Api.scoring_functions],
impls[Api.datasetio],
impls[Api.datasets],
)

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
# How to run this test:
#
# pytest llama_stack/providers/tests/eval/test_eval.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
class Testeval:
@pytest.mark.asyncio
async def test_eval_tasks_list(self, eval_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, eval_tasks_impl, _, _, _, _ = eval_stack
response = await eval_tasks_impl.list_eval_tasks()
assert isinstance(response, list)
print(response)

View file

@ -21,12 +21,24 @@ DEFAULT_PROVIDER_COMBINATIONS = [
},
id="meta_reference_scoring_fireworks_inference",
marks=pytest.mark.meta_reference_scoring_fireworks_inference,
)
),
pytest.param(
{
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "together",
},
id="meta_reference_scoring_together_inference",
marks=pytest.mark.meta_reference_scoring_together_inference,
),
]
def pytest_configure(config):
for fixture_name in ["meta_reference_scoring_fireworks_inference"]:
for fixture_name in [
"meta_reference_scoring_fireworks_inference",
"meta_reference_scoring_together_inference",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",