forked from phoenix-oss/llama-stack-mirror
[test automation] support run tests on config file (#730)
# Context For test automation, the end goal is to run a single pytest command from root test directory (llama_stack/providers/tests/.) such that we execute push-blocking tests The work plan: 1) trigger pytest from llama_stack/providers/tests/. 2) use config file to determine what tests and parametrization we want to run # What does this PR do? 1) consolidates the "inference-models" / "embedding-model" / "judge-model" ... options in root conftest.py. Without this change, we will hit into error when trying to run `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/.` because of duplicated `addoptions` definitions across child conftest files. 2) Add a `config` option to specify test config in YAML. (see [`ci_test_config.yaml`](https://gist.github.com/sixianyi0721/5b37fbce4069139445c2f06f6e42f87e) for example config file) For provider_fixtures, we allow users to use either a default fixture combination or define their own {api:provider} combinations. ``` memory: .... fixtures: provider_fixtures: - default_fixture_param_id: ollama // use default fixture combination with param_id="ollama" in [providers/tests/memory/conftest.py](https://fburl.com/mtjzwsmk) - inference: sentence_transformers memory: faiss - default_fixture_param_id: chroma ``` 3) generate tests according to the config. Logic lives in two places: a) in `{api}/conftest.py::pytest_generate_tests`, we read from config to do parametrization. b) after test collection, in `pytest_collection_modifyitems`, we filter the tests to include only functions listed in config. ## Test Plan 1) `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/. --collect-only --config=ci_test_config.yaml` Using `--collect-only` tag to print the pytests listed in the config file (`ci_test_config.yaml`). output: [gist](https://gist.github.com/sixianyi0721/05145e60d4d085c17cfb304beeb1e60e) 2) sanity check on `--inference-model` option ``` pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py ``` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
74e4d520ac
commit
c79b087552
14 changed files with 273 additions and 116 deletions
|
@ -5,12 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.distribution.datatypes import Provider
|
||||
|
@ -24,6 +28,83 @@ class ProviderFixture(BaseModel):
|
|||
provider_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TestScenario(BaseModel):
|
||||
# provider fixtures can be either a mark or a dictionary of api -> providers
|
||||
provider_fixtures: Dict[str, str] = Field(default_factory=dict)
|
||||
fixture_combo_id: Optional[str] = None
|
||||
|
||||
|
||||
class APITestConfig(BaseModel):
|
||||
scenarios: List[TestScenario] = Field(default_factory=list)
|
||||
inference_models: List[str] = Field(default_factory=list)
|
||||
|
||||
# test name format should be <relative_path.py>::<test_name>
|
||||
tests: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MemoryApiTestConfig(APITestConfig):
|
||||
embedding_model: Optional[str] = Field(default_factory=None)
|
||||
|
||||
|
||||
class AgentsApiTestConfig(APITestConfig):
|
||||
safety_shield: Optional[str] = Field(default_factory=None)
|
||||
|
||||
|
||||
class TestConfig(BaseModel):
|
||||
inference: Optional[APITestConfig] = None
|
||||
agents: Optional[AgentsApiTestConfig] = None
|
||||
memory: Optional[MemoryApiTestConfig] = None
|
||||
|
||||
|
||||
def get_test_config_from_config_file(metafunc_config):
|
||||
config_file = metafunc_config.getoption("--config")
|
||||
if config_file is None:
|
||||
return None
|
||||
|
||||
config_file_path = Path(__file__).parent / config_file
|
||||
if not config_file_path.exists():
|
||||
raise ValueError(
|
||||
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
|
||||
)
|
||||
with open(config_file_path, "r") as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
return TestConfig(**config)
|
||||
|
||||
|
||||
def get_test_config_for_api(metafunc_config, api):
|
||||
test_config = get_test_config_from_config_file(metafunc_config)
|
||||
if test_config is None:
|
||||
return None
|
||||
return getattr(test_config, api)
|
||||
|
||||
|
||||
def get_provider_fixture_overrides_from_test_config(
|
||||
metafunc_config, api, default_provider_fixture_combinations
|
||||
):
|
||||
api_config = get_test_config_for_api(metafunc_config, api)
|
||||
if api_config is None:
|
||||
return None
|
||||
|
||||
fixture_combo_ids = set()
|
||||
custom_provider_fixture_combos = []
|
||||
for scenario in api_config.scenarios:
|
||||
if scenario.fixture_combo_id:
|
||||
fixture_combo_ids.add(scenario.fixture_combo_id)
|
||||
else:
|
||||
custom_provider_fixture_combos.append(
|
||||
pytest.param(
|
||||
scenario.provider_fixtures,
|
||||
id=scenario.provider_fixtures.get("inference") or "",
|
||||
)
|
||||
)
|
||||
|
||||
if len(fixture_combo_ids) > 0:
|
||||
for default_fixture in default_provider_fixture_combinations:
|
||||
if default_fixture.id in fixture_combo_ids:
|
||||
custom_provider_fixture_combos.append(default_fixture)
|
||||
return custom_provider_fixture_combos
|
||||
|
||||
|
||||
def remote_stack_fixture() -> ProviderFixture:
|
||||
if url := os.getenv("REMOTE_STACK_URL", None):
|
||||
config = RemoteProviderConfig.from_url(url)
|
||||
|
@ -69,10 +150,39 @@ def pytest_addoption(parser):
|
|||
"Example: --providers inference=ollama,safety=meta-reference"
|
||||
),
|
||||
)
|
||||
parser.addoption(
|
||||
"--config",
|
||||
action="store",
|
||||
help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
|
||||
)
|
||||
"""Add custom command line options"""
|
||||
parser.addoption(
|
||||
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
|
||||
)
|
||||
parser.addoption(
|
||||
"--inference-model",
|
||||
action="store",
|
||||
default="meta-llama/Llama-3.2-3B-Instruct",
|
||||
help="Specify the inference model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-shield",
|
||||
action="store",
|
||||
default="meta-llama/Llama-Guard-3-1B",
|
||||
help="Specify the safety shield to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--embedding-model",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Specify the embedding model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--judge-model",
|
||||
action="store",
|
||||
default="meta-llama/Llama-3.1-8B-Instruct",
|
||||
help="Specify the judge model to use for testing",
|
||||
)
|
||||
|
||||
|
||||
def make_provider_id(providers: Dict[str, str]) -> str:
|
||||
|
@ -148,6 +258,38 @@ def pytest_itemcollected(item):
|
|||
item.name = f"{item.name}[{marks}]"
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(session, config, items):
|
||||
test_config = get_test_config_from_config_file(config)
|
||||
if test_config is None:
|
||||
return
|
||||
|
||||
required_tests = defaultdict(set)
|
||||
for api_test_config in [
|
||||
test_config.inference,
|
||||
test_config.memory,
|
||||
test_config.agents,
|
||||
]:
|
||||
if api_test_config is None:
|
||||
continue
|
||||
for test in api_test_config.tests:
|
||||
arr = test.split("::")
|
||||
if len(arr) != 2:
|
||||
raise ValueError(f"Invalid format for test name {test}")
|
||||
test_path, func_name = arr
|
||||
required_tests[Path(__file__).parent / test_path].add(func_name)
|
||||
|
||||
new_items, deselected_items = [], []
|
||||
for item in items:
|
||||
func_name = getattr(item, "originalname", item.name)
|
||||
if func_name in required_tests[item.fspath]:
|
||||
new_items.append(item)
|
||||
continue
|
||||
deselected_items.append(item)
|
||||
|
||||
items[:] = new_items
|
||||
config.hook.pytest_deselected(items=deselected_items)
|
||||
|
||||
|
||||
pytest_plugins = [
|
||||
"llama_stack.providers.tests.inference.fixtures",
|
||||
"llama_stack.providers.tests.safety.fixtures",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue