llama-stack-mirror/llama_stack/providers/tests/conftest.py
Sixian Yi c79b087552
[test automation] support run tests on config file (#730)
# Context
For test automation, the end goal is to run a single pytest command from
root test directory (llama_stack/providers/tests/.) such that we execute
push-blocking tests

The work plan: 
1) trigger pytest from llama_stack/providers/tests/.
2) use config file to determine what tests and parametrization we want
to run

# What does this PR do?
1) Consolidates the "inference-model" / "embedding-model" /
"judge-model" ... options in the root conftest.py. Without this change, we
hit an error when trying to run `pytest
/Users/sxyi/llama-stack/llama_stack/providers/tests/.` because of
duplicated `addoption` definitions across child conftest files.

2) Add a `config` option to specify test config in YAML. (see
[`ci_test_config.yaml`](https://gist.github.com/sixianyi0721/5b37fbce4069139445c2f06f6e42f87e)
for example config file)

For provider_fixtures, we allow users to use either a default fixture
combination or define their own {api:provider} combinations.

```

memory:
....
  fixtures:
    provider_fixtures:
    - default_fixture_param_id: ollama // use default fixture combination with param_id="ollama" in [providers/tests/memory/conftest.py](https://fburl.com/mtjzwsmk)
    - inference: sentence_transformers
      memory: faiss
    - default_fixture_param_id: chroma

```
3) generate tests according to the config. Logic lives in two places: 
a) in `{api}/conftest.py::pytest_generate_tests`, we read from config to
do parametrization.
b) after test collection, in `pytest_collection_modifyitems`, we filter
the tests to include only functions listed in config.

## Test Plan
1) `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/.
--collect-only --config=ci_test_config.yaml`

Use the `--collect-only` flag to print the tests selected by the config
file (`ci_test_config.yaml`) without running them.

output:
[gist](https://gist.github.com/sixianyi0721/05145e60d4d085c17cfb304beeb1e60e)


2) sanity check on `--inference-model` option

```
pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
```


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2025-01-16 12:05:49 -08:00

303 lines
9.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional
import pytest
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from termcolor import colored
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig
from .env import get_env_or_fail
class ProviderFixture(BaseModel):
    """A list of providers, plus optional provider data, returned by a fixture."""

    # Provider entries that make up this fixture.
    providers: List[Provider]
    # Optional free-form mapping of provider data; None when not supplied.
    provider_data: Optional[Dict[str, Any]] = None
class TestScenario(BaseModel):
    """One provider-fixture scenario listed in the test config file."""

    # Provider fixtures are given either as an explicit mapping of
    # api -> provider fixture name (this field) ...
    provider_fixtures: Dict[str, str] = Field(default_factory=dict)
    # ... or as the id of one of the default fixture combinations.
    fixture_combo_id: Optional[str] = None
class APITestConfig(BaseModel):
    """Configuration shared by every per-API section of the test config file."""

    # Provider-fixture scenarios to parametrize tests with.
    scenarios: List[TestScenario] = Field(default_factory=list)
    # Inference model names to run the tests against.
    inference_models: List[str] = Field(default_factory=list)
    # test name format should be <relative_path.py>::<test_name>
    tests: List[str] = Field(default_factory=list)
class MemoryApiTestConfig(APITestConfig):
    """Test configuration for the memory API.

    Extends the common per-API configuration with an optional embedding model.
    """

    # Use `default=None`, not `default_factory=None`: a factory of None
    # supplies no default at all, which would make this Optional field
    # required when the config file omits it.
    embedding_model: Optional[str] = Field(default=None)
class AgentsApiTestConfig(APITestConfig):
    """Test configuration for the agents API.

    Extends the common per-API configuration with an optional safety shield.
    """

    # Use `default=None`, not `default_factory=None`: a factory of None
    # supplies no default at all, which would make this Optional field
    # required when the config file omits it.
    safety_shield: Optional[str] = Field(default=None)
class TestConfig(BaseModel):
    """Top-level test config: one optional section per supported API."""

    # Each section is None when it is not supplied.
    inference: Optional[APITestConfig] = None
    agents: Optional[AgentsApiTestConfig] = None
    memory: Optional[MemoryApiTestConfig] = None
def get_test_config_from_config_file(metafunc_config):
    """Load the test configuration named by the ``--config`` option.

    Args:
        metafunc_config: Object exposing ``getoption``; it is queried for
            the ``--config`` value.

    Returns:
        A ``TestConfig`` built from the YAML file, or ``None`` when
        ``--config`` was not given. An empty YAML file yields a
        ``TestConfig`` with every section unset.

    Raises:
        ValueError: If the named file does not exist relative to the
            directory containing this file.
    """
    config_file = metafunc_config.getoption("--config")
    if config_file is None:
        return None

    # The config path is resolved relative to this file's directory.
    config_file_path = Path(__file__).parent / config_file
    if not config_file_path.exists():
        raise ValueError(
            f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
        )

    # Use a distinct handle name so the option value is not shadowed.
    with open(config_file_path, "r", encoding="utf-8") as stream:
        config = yaml.safe_load(stream)

    # safe_load returns None for an empty document; treat that as "no
    # sections" instead of failing on `**None`.
    return TestConfig(**(config or {}))
def get_test_config_for_api(metafunc_config, api):
    """Return the ``api`` section of the loaded test config.

    Returns ``None`` when no test config was loaded; otherwise the
    attribute named ``api`` on the loaded config.
    """
    loaded = get_test_config_from_config_file(metafunc_config)
    return None if loaded is None else getattr(loaded, api)
def get_provider_fixture_overrides_from_test_config(
    metafunc_config, api, default_provider_fixture_combinations
):
    """Build the provider-fixture parametrization requested by the test config.

    Returns ``None`` when there is no config for ``api``. Otherwise returns a
    list that holds, first, one ``pytest.param`` per scenario that spells out
    its own provider fixtures, followed by every default combination whose
    ``id`` is named by a scenario's ``fixture_combo_id``.
    """
    api_config = get_test_config_for_api(metafunc_config, api)
    if api_config is None:
        return None

    scenarios = api_config.scenarios

    # Ids of default combinations that scenarios refer to.
    wanted_ids = {s.fixture_combo_id for s in scenarios if s.fixture_combo_id}

    # Scenarios without a combo id carry an explicit api -> provider mapping.
    combos = [
        pytest.param(
            s.provider_fixtures,
            id=s.provider_fixtures.get("inference") or "",
        )
        for s in scenarios
        if not s.fixture_combo_id
    ]

    if wanted_ids:
        combos.extend(
            candidate
            for candidate in default_provider_fixture_combinations
            if candidate.id in wanted_ids
        )
    return combos
def remote_stack_fixture() -> ProviderFixture:
    """Build a fixture holding a single ``test::remote`` provider.

    The remote config comes from ``REMOTE_STACK_URL`` when that variable is
    set to a non-empty value, and otherwise from ``REMOTE_STACK_HOST`` and
    ``REMOTE_STACK_PORT`` (both read through ``get_env_or_fail``).
    """
    url = os.getenv("REMOTE_STACK_URL", None)
    if url:
        remote_config = RemoteProviderConfig.from_url(url)
    else:
        host = get_env_or_fail("REMOTE_STACK_HOST")
        port = int(get_env_or_fail("REMOTE_STACK_PORT"))
        remote_config = RemoteProviderConfig(host=host, port=port)

    remote_provider = Provider(
        provider_id="test::remote",
        provider_type="test::remote",
        config=remote_config.model_dump(),
    )
    return ProviderFixture(providers=[remote_provider])
def pytest_configure(config):
    """Load environment variables at start of test run.

    Also sets the traceback style to "short" and disables warnings on
    ``config.option``.

    Environment variables come from two places, in this order:
    a ``.env`` file next to this file (if present), then every ``KEY=value``
    entry passed through ``--env``.

    Raises:
        ValueError: If an ``--env`` entry contains no ``=``.
    """
    config.option.tbstyle = "short"
    config.option.disable_warnings = True

    # Load from .env file if it exists
    env_file = Path(__file__).parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

    # Load any environment variables passed via --env
    for env_var in config.getoption("--env") or []:
        # Split on the first "=" only so values may themselves contain "=".
        key, separator, value = env_var.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid --env value '{env_var}'. Expected format: KEY=value"
            )
        os.environ[key] = value
def pytest_addoption(parser):
    """Add custom command line options.

    Registers ``--providers``, ``--config``, ``--env``, ``--inference-model``,
    ``--safety-shield``, ``--embedding-model`` and ``--judge-model`` on
    ``parser``.
    """
    parser.addoption(
        "--providers",
        default="",
        help=(
            "Provider configuration in format: api1=provider1,api2=provider2. "
            "Example: --providers inference=ollama,safety=meta-reference"
        ),
    )
    parser.addoption(
        "--config",
        action="store",
        help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
    )
    # `append` lets --env be given several times, once per variable.
    parser.addoption(
        "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
    )
    parser.addoption(
        "--inference-model",
        action="store",
        default="meta-llama/Llama-3.2-3B-Instruct",
        help="Specify the inference model to use for testing",
    )
    parser.addoption(
        "--safety-shield",
        action="store",
        default="meta-llama/Llama-Guard-3-1B",
        help="Specify the safety shield to use for testing",
    )
    parser.addoption(
        "--embedding-model",
        action="store",
        default=None,
        help="Specify the embedding model to use for testing",
    )
    parser.addoption(
        "--judge-model",
        action="store",
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Specify the judge model to use for testing",
    )
def make_provider_id(providers: Dict[str, str]) -> str:
    """Join ``api=provider`` pairs, ordered by API name, with ``:``."""
    parts = []
    for api in sorted(providers):
        parts.append(f"{api}={providers[api]}")
    return ":".join(parts)
def get_provider_marks(providers: Dict[str, str]) -> List[Any]:
    """Return one ``pytest.mark.<provider>`` per provider value, in dict order."""
    return [getattr(pytest.mark, provider_name) for provider_name in providers.values()]
def get_provider_fixture_overrides(
    config, available_fixtures: Dict[str, List[str]]
) -> Optional[List[pytest.param]]:
    """Turn the ``--providers`` option into a single-element param list.

    Returns ``None`` when ``--providers`` is empty. Otherwise the option is
    parsed against ``available_fixtures`` and wrapped in one ``pytest.param``
    whose id and marks are derived from the parsed mapping.
    """
    provider_str = config.getoption("--providers")
    if not provider_str:
        return None

    selected = parse_fixture_string(provider_str, available_fixtures)
    override = pytest.param(
        selected,
        id=make_provider_id(selected),
        marks=get_provider_marks(selected),
    )
    return [override]
def parse_fixture_string(
    provider_str: str, available_fixtures: Dict[str, List[str]]
) -> Dict[str, str]:
    """Parse provider string of format 'api1=provider1,api2=provider2'.

    Whitespace around each API and provider name is ignored.

    Args:
        provider_str: Comma-separated ``api=provider`` pairs.
        available_fixtures: Maps each API to its known provider fixture names.

    Returns:
        Mapping of API name to the chosen provider fixture; an empty dict
        when ``provider_str`` is empty.

    Raises:
        ValueError: If a pair has no ``=``, names an unknown API or provider,
            or if any API in ``available_fixtures`` is left without a provider.
    """
    if not provider_str:
        return {}

    fixtures = {}
    for pair in provider_str.split(","):
        if "=" not in pair:
            raise ValueError(
                f"Invalid provider specification: {pair}. Expected format: api=provider"
            )
        # Split on the first "=" only: a stray second "=" then surfaces as an
        # "Unknown provider" error instead of an opaque unpacking failure.
        api, fixture = (part.strip() for part in pair.split("=", 1))
        if api not in available_fixtures:
            raise ValueError(
                f"Unknown API: {api}. Available APIs: {list(available_fixtures.keys())}"
            )
        if fixture not in available_fixtures[api]:
            raise ValueError(
                f"Unknown provider '{fixture}' for API '{api}'. "
                f"Available providers: {list(available_fixtures[api])}"
            )
        fixtures[api] = fixture

    # Check that every available API was given a provider
    for api in available_fixtures:
        if api not in fixtures:
            raise ValueError(
                f"Missing provider fixture for API '{api}'. Available providers: "
                f"{list(available_fixtures[api])}"
            )
    return fixtures
def pytest_itemcollected(item):
    """Append the item's marker names, in yellow, to its displayed name.

    The ``asyncio`` and ``parametrize`` markers are left out; an item with no
    other markers keeps its name unchanged.
    """
    hidden = ("asyncio", "parametrize")
    visible = []
    for marker in item.iter_markers():
        if marker.name not in hidden:
            visible.append(marker.name)

    if not visible:
        return

    label = colored(",".join(visible), "yellow")
    item.name = f"{item.name}[{label}]"
def pytest_collection_modifyitems(session, config, items):
    """Keep only the collected tests listed in the ``--config`` file.

    Does nothing when no test config is loaded. Otherwise every test entry
    from the inference, memory and agents sections (format
    ``<relative_path.py>::<test_name>``) is gathered per file, ``items`` is
    filtered in place to the matching functions, and the rest are reported
    through the ``pytest_deselected`` hook.

    Raises:
        ValueError: If a configured test name is not exactly
            ``<path>::<name>``.
    """
    test_config = get_test_config_from_config_file(config)
    if test_config is None:
        return

    # Test paths in the config are relative to this file's directory.
    tests_root = Path(__file__).parent
    required_tests = defaultdict(set)
    for api_test_config in (
        test_config.inference,
        test_config.memory,
        test_config.agents,
    ):
        if api_test_config is None:
            continue
        for test in api_test_config.tests:
            arr = test.split("::")
            if len(arr) != 2:
                raise ValueError(f"Invalid format for test name {test}")
            test_path, func_name = arr
            required_tests[tests_root / test_path].add(func_name)

    new_items, deselected_items = [], []
    for item in items:
        # Parametrized items carry the bare function name in `originalname`.
        func_name = getattr(item, "originalname", item.name)
        # `item.fspath` is a legacy path object rather than a pathlib.Path;
        # convert it so the lookup does not depend on the two types hashing
        # identically. `.get` also avoids inserting keys into the defaultdict.
        item_path = Path(str(item.fspath))
        if func_name in required_tests.get(item_path, ()):
            new_items.append(item)
        else:
            deselected_items.append(item)

    items[:] = new_items
    if deselected_items:
        config.hook.pytest_deselected(items=deselected_items)
# Per-API fixture modules listed under pytest's `pytest_plugins` variable.
pytest_plugins = [
    "llama_stack.providers.tests.inference.fixtures",
    "llama_stack.providers.tests.safety.fixtures",
    "llama_stack.providers.tests.memory.fixtures",
    "llama_stack.providers.tests.agents.fixtures",
    "llama_stack.providers.tests.datasetio.fixtures",
    "llama_stack.providers.tests.scoring.fixtures",
    "llama_stack.providers.tests.eval.fixtures",
    "llama_stack.providers.tests.post_training.fixtures",
    "llama_stack.providers.tests.tools.fixtures",
]