[test automation] Support running tests from a config file (#730)
# Context

For test automation, the end goal is to run a single pytest command from the root test directory (llama_stack/providers/tests/.) that executes all push-blocking tests.

The work plan:

1) trigger pytest from llama_stack/providers/tests/.
2) use a config file to determine which tests and parametrizations we want to run

# What does this PR do?

1) Consolidates the `--inference-model` / `--embedding-model` / `--judge-model` ... options in the root conftest.py. Without this change, running `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/.` fails because of duplicate `pytest_addoption` definitions across child conftest files.

2) Adds a `--config` option for specifying a test config in YAML (see [`ci_test_config.yaml`](https://gist.github.com/sixianyi0721/5b37fbce4069139445c2f06f6e42f87e) for an example config file). For provider_fixtures, users can either pick a default fixture combination by its param id or define their own {api: provider} combinations:

```
memory:
  ....
  fixtures:
    provider_fixtures:
      - default_fixture_param_id: ollama  # use the default fixture combination with param_id="ollama" in [providers/tests/memory/conftest.py](https://fburl.com/mtjzwsmk)
      - inference: sentence_transformers
        memory: faiss
      - default_fixture_param_id: chroma
```

3) Generates tests according to the config. The logic lives in two places:

   a) in `{api}/conftest.py::pytest_generate_tests`, we read the config to do the parametrization.

   b) after test collection, in `pytest_collection_modifyitems`, we filter the collected tests down to the functions listed in the config.

## Test Plan

1) `pytest /Users/sxyi/llama-stack/llama_stack/providers/tests/. --collect-only --config=ci_test_config.yaml`

   The `--collect-only` flag prints the tests selected by the config file (`ci_test_config.yaml`). Output: [gist](https://gist.github.com/sixianyi0721/05145e60d4d085c17cfb304beeb1e60e)

2) Sanity check on the `--inference-model` option:

```
pytest -v -s -k "ollama" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
```

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
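The net effect in each `{api}/conftest.py` is a three-level precedence: config-file scenarios win, then CLI `--providers` overrides, then the hardcoded defaults. Excerpted from the memory conftest change in the diff below, with comments added:

```python
combinations = (
    # 1) scenarios listed under this api in the YAML config (when --config is given)
    get_provider_fixture_overrides_from_test_config(
        metafunc.config, "memory", DEFAULT_PROVIDER_COMBINATIONS
    )
    # 2) otherwise, explicit CLI overrides such as --providers inference=ollama
    or get_provider_fixture_overrides(metafunc.config, available_fixtures)
    # 3) otherwise, the api's default fixture combinations
    or DEFAULT_PROVIDER_COMBINATIONS
)
```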
parent 74e4d520ac
commit c79b087552

14 changed files with 273 additions and 116 deletions
**`llama_stack/distribution/routers/routing_tables.py`**

```diff
@@ -246,7 +246,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             provider_id = list(self.impls_by_provider_id.keys())[0]
         else:
             raise ValueError(
-                "No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
+                f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
             )
         if metadata is None:
             metadata = {}
```
**`llama_stack/providers/tests/README.md`**

````diff
@@ -10,6 +10,8 @@ We use `pytest` and all of its dynamism to enable the features needed. Specifica
 
 - We use `pytest_configure` to make sure we dynamically add appropriate marks based on the fixtures we make.
 
+- We use `pytest_collection_modifyitems` to filter tests based on the test config (if specified).
+
 ## Common options
 
 All tests support a `--providers` option which can be a string of the form `api1=provider_fixture1,api2=provider_fixture2`. So, when testing safety (which need inference and safety APIs) you can use `--providers inference=together,safety=meta_reference` to use these fixtures in concert.
@@ -73,3 +75,15 @@ If you wanted to test a remotely hosted stack, you can use `-m remote` as follow
 pytest -s -m remote llama_stack/providers/tests/agents/test_agents.py \
   --env REMOTE_STACK_URL=<...>
 ```
+
+## Test Config
+If you want to run a test suite with a custom set of tests and parametrizations, you can define a YAML test config under the llama_stack/providers/tests/ folder and pass the filename through the `--config` option as follows:
+
+```
+pytest llama_stack/providers/tests/ --config=ci_test_config.yaml
+```
+
+### Test config format
+Currently, we support test configs for the inference, agents and memory api tests.
+
+An example test config can be found in ci_test_config.yaml.
````
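To preview which tests a given config will select without running them, the `--collect-only` invocation from the test plan applies here too:

```
pytest llama_stack/providers/tests/ --collect-only --config=ci_test_config.yaml
```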
**`llama_stack/providers/tests/agents/conftest.py`**

```diff
@@ -6,10 +6,15 @@
 
 import pytest
 
-from ..conftest import get_provider_fixture_overrides
+from ..conftest import (
+    get_provider_fixture_overrides,
+    get_provider_fixture_overrides_from_test_config,
+    get_test_config_for_api,
+)
+
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 
 from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
 from .fixtures import AGENTS_FIXTURES
@@ -81,23 +86,15 @@ def pytest_configure(config):
         )
 
 
-def pytest_addoption(parser):
-    parser.addoption(
-        "--inference-model",
-        action="store",
-        default="meta-llama/Llama-3.2-3B-Instruct",
-        help="Specify the inference model to use for testing",
-    )
-    parser.addoption(
-        "--safety-shield",
-        action="store",
-        default="meta-llama/Llama-Guard-3-1B",
-        help="Specify the safety shield to use for testing",
-    )
-
-
 def pytest_generate_tests(metafunc):
-    shield_id = metafunc.config.getoption("--safety-shield")
+    test_config = get_test_config_for_api(metafunc.config, "agents")
+    shield_id = getattr(
+        test_config, "safety_shield", None
+    ) or metafunc.config.getoption("--safety-shield")
+    inference_models = getattr(test_config, "inference_models", None) or [
+        metafunc.config.getoption("--inference-model")
+    ]
+
     if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_shield",
@@ -105,8 +102,7 @@ def pytest_generate_tests(metafunc):
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
-        inference_model = metafunc.config.getoption("--inference-model")
-        models = set({inference_model})
+        models = set(inference_models)
         if safety_model := safety_model_from_shield(shield_id):
             models.add(safety_model)
 
@@ -124,7 +120,10 @@ def pytest_generate_tests(metafunc):
             "tool_runtime": TOOL_RUNTIME_FIXTURES,
         }
         combinations = (
-            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            get_provider_fixture_overrides_from_test_config(
+                metafunc.config, "agents", DEFAULT_PROVIDER_COMBINATIONS
+            )
+            or get_provider_fixture_overrides(metafunc.config, available_fixtures)
             or DEFAULT_PROVIDER_COMBINATIONS
         )
         metafunc.parametrize("agents_stack", combinations, indirect=True)
```
**`llama_stack/providers/tests/agents/test_agents.py`**

```diff
@@ -9,7 +9,9 @@ import pytest
 from llama_stack.apis.agents import AgentConfig, Turn
 from llama_stack.apis.inference import SamplingParams, UserMessage
 from llama_stack.providers.datatypes import Api
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 from .fixtures import pick_inference_model
 
 from .utils import create_agent_session
```
**`llama_stack/providers/tests/ci_test_config.yaml`** (new file, 55 lines)

```yaml
inference:
  tests:
    - inference/test_vision_inference.py::test_vision_chat_completion_streaming
    - inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
    - inference/test_text_inference.py::test_structured_output
    - inference/test_text_inference.py::test_chat_completion_streaming
    - inference/test_text_inference.py::test_chat_completion_non_streaming
    - inference/test_text_inference.py::test_chat_completion_with_tool_calling
    - inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming

  scenarios:
    - provider_fixtures:
        inference: ollama
    - fixture_combo_id: fireworks
    - provider_fixtures:
        inference: together
    # - inference: tgi
    # - inference: vllm_remote

  inference_models:
    - meta-llama/Llama-3.1-8B-Instruct
    - meta-llama/Llama-3.2-11B-Vision-Instruct


agents:
  tests:
    - agents/test_agents.py::test_agent_turns_with_safety
    - agents/test_agents.py::test_rag_agent

  scenarios:
    - fixture_combo_id: ollama
    - fixture_combo_id: together
    - fixture_combo_id: fireworks

  inference_models:
    - meta-llama/Llama-3.2-1B-Instruct

  safety_shield: meta-llama/Llama-Guard-3-1B


memory:
  tests:
    - memory/test_memory.py::test_query_documents

  scenarios:
    - fixture_combo_id: ollama
    - provider_fixtures:
        inference: sentence_transformers
        memory: faiss
    - fixture_combo_id: chroma

  inference_models:
    - meta-llama/Llama-3.2-1B-Instruct

  embedding_model: all-MiniLM-L6-v2
```
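The two scenario styles in this file connect directly to the `get_provider_fixture_overrides_from_test_config` helper in the root conftest below: a `fixture_combo_id` entry must match the param id of one of the `DEFAULT_PROVIDER_COMBINATIONS` defined in the corresponding `{api}/conftest.py` (e.g. `ollama`, `fireworks`), while a `provider_fixtures` entry spells out the api-to-fixture mapping directly (e.g. `inference: sentence_transformers` with `memory: faiss`).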
**`llama_stack/providers/tests/conftest.py`**

```diff
@@ -5,12 +5,16 @@
 # the root directory of this source tree.
 
 import os
+from collections import defaultdict
+
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import pytest
+import yaml
+
 from dotenv import load_dotenv
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from termcolor import colored
 
 from llama_stack.distribution.datatypes import Provider
@@ -24,6 +28,83 @@ class ProviderFixture(BaseModel):
     provider_data: Optional[Dict[str, Any]] = None
 
 
+class TestScenario(BaseModel):
+    # provider fixtures can be either a mark or a dictionary of api -> providers
+    provider_fixtures: Dict[str, str] = Field(default_factory=dict)
+    fixture_combo_id: Optional[str] = None
+
+
+class APITestConfig(BaseModel):
+    scenarios: List[TestScenario] = Field(default_factory=list)
+    inference_models: List[str] = Field(default_factory=list)
+
+    # test name format should be <relative_path.py>::<test_name>
+    tests: List[str] = Field(default_factory=list)
+
+
+class MemoryApiTestConfig(APITestConfig):
+    embedding_model: Optional[str] = Field(default_factory=None)
+
+
+class AgentsApiTestConfig(APITestConfig):
+    safety_shield: Optional[str] = Field(default_factory=None)
+
+
+class TestConfig(BaseModel):
+    inference: Optional[APITestConfig] = None
+    agents: Optional[AgentsApiTestConfig] = None
+    memory: Optional[MemoryApiTestConfig] = None
+
+
+def get_test_config_from_config_file(metafunc_config):
+    config_file = metafunc_config.getoption("--config")
+    if config_file is None:
+        return None
+
+    config_file_path = Path(__file__).parent / config_file
+    if not config_file_path.exists():
+        raise ValueError(
+            f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
+        )
+    with open(config_file_path, "r") as config_file:
+        config = yaml.safe_load(config_file)
+        return TestConfig(**config)
+
+
+def get_test_config_for_api(metafunc_config, api):
+    test_config = get_test_config_from_config_file(metafunc_config)
+    if test_config is None:
+        return None
+    return getattr(test_config, api)
+
+
+def get_provider_fixture_overrides_from_test_config(
+    metafunc_config, api, default_provider_fixture_combinations
+):
+    api_config = get_test_config_for_api(metafunc_config, api)
+    if api_config is None:
+        return None
+
+    fixture_combo_ids = set()
+    custom_provider_fixture_combos = []
+    for scenario in api_config.scenarios:
+        if scenario.fixture_combo_id:
+            fixture_combo_ids.add(scenario.fixture_combo_id)
+        else:
+            custom_provider_fixture_combos.append(
+                pytest.param(
+                    scenario.provider_fixtures,
+                    id=scenario.provider_fixtures.get("inference") or "",
+                )
+            )
+
+    if len(fixture_combo_ids) > 0:
+        for default_fixture in default_provider_fixture_combinations:
+            if default_fixture.id in fixture_combo_ids:
+                custom_provider_fixture_combos.append(default_fixture)
+    return custom_provider_fixture_combos
+
+
 def remote_stack_fixture() -> ProviderFixture:
     if url := os.getenv("REMOTE_STACK_URL", None):
         config = RemoteProviderConfig.from_url(url)
```
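Since `TestConfig` mirrors the YAML layout one-to-one, loading a config is just `yaml.safe_load` plus pydantic validation, as `get_test_config_from_config_file` above does. A hypothetical sanity check (not part of the diff) against the `ci_test_config.yaml` shown earlier:

```python
import yaml

# Parse the example config; TestConfig is the pydantic model defined above.
with open("llama_stack/providers/tests/ci_test_config.yaml") as f:
    test_config = TestConfig(**yaml.safe_load(f))

assert test_config.memory.embedding_model == "all-MiniLM-L6-v2"
assert test_config.agents.safety_shield == "meta-llama/Llama-Guard-3-1B"
assert "memory/test_memory.py::test_query_documents" in test_config.memory.tests
```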
**`llama_stack/providers/tests/conftest.py`**

```diff
@@ -69,10 +150,39 @@ def pytest_addoption(parser):
             "Example: --providers inference=ollama,safety=meta-reference"
         ),
     )
+    parser.addoption(
+        "--config",
+        action="store",
+        help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
+    )
     """Add custom command line options"""
     parser.addoption(
         "--env", action="append", help="Set environment variables, e.g. --env KEY=value"
     )
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default="meta-llama/Llama-3.2-3B-Instruct",
+        help="Specify the inference model to use for testing",
+    )
+    parser.addoption(
+        "--safety-shield",
+        action="store",
+        default="meta-llama/Llama-Guard-3-1B",
+        help="Specify the safety shield to use for testing",
+    )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
+    parser.addoption(
+        "--judge-model",
+        action="store",
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Specify the judge model to use for testing",
+    )
 
 
 def make_provider_id(providers: Dict[str, str]) -> str:
@@ -148,6 +258,38 @@ def pytest_itemcollected(item):
         item.name = f"{item.name}[{marks}]"
 
 
+def pytest_collection_modifyitems(session, config, items):
+    test_config = get_test_config_from_config_file(config)
+    if test_config is None:
+        return
+
+    required_tests = defaultdict(set)
+    for api_test_config in [
+        test_config.inference,
+        test_config.memory,
+        test_config.agents,
+    ]:
+        if api_test_config is None:
+            continue
+        for test in api_test_config.tests:
+            arr = test.split("::")
+            if len(arr) != 2:
+                raise ValueError(f"Invalid format for test name {test}")
+            test_path, func_name = arr
+            required_tests[Path(__file__).parent / test_path].add(func_name)
+
+    new_items, deselected_items = [], []
+    for item in items:
+        func_name = getattr(item, "originalname", item.name)
+        if func_name in required_tests[item.fspath]:
+            new_items.append(item)
+            continue
+        deselected_items.append(item)
+
+    items[:] = new_items
+    config.hook.pytest_deselected(items=deselected_items)
+
+
 pytest_plugins = [
     "llama_stack.providers.tests.inference.fixtures",
     "llama_stack.providers.tests.safety.fixtures",
```
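The filter above matches on `originalname`, so the config lists bare function names while collected items may carry parametrization suffixes. A small illustration (the item name here is assumed, not taken from the diff):

```python
# A parametrized item collected as "test_query_documents[ollama]" reports
# originalname "test_query_documents", which is the form the YAML config uses.
item_name = "test_query_documents[ollama]"
func_name = item_name.split("[", 1)[0]  # approximates item.originalname
assert func_name == "test_query_documents"
```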
**`llama_stack/providers/tests/eval/conftest.py`**

```diff
@@ -76,22 +76,6 @@ def pytest_configure(config):
         )
 
 
-def pytest_addoption(parser):
-    parser.addoption(
-        "--inference-model",
-        action="store",
-        default="meta-llama/Llama-3.2-3B-Instruct",
-        help="Specify the inference model to use for testing",
-    )
-
-    parser.addoption(
-        "--judge-model",
-        action="store",
-        default="meta-llama/Llama-3.1-8B-Instruct",
-        help="Specify the judge model to use for testing",
-    )
-
-
 def pytest_generate_tests(metafunc):
     if "eval_stack" in metafunc.fixturenames:
         available_fixtures = {
```
**`llama_stack/providers/tests/inference/conftest.py`**

```diff
@@ -6,26 +6,10 @@
 
 import pytest
 
-from ..conftest import get_provider_fixture_overrides
+from ..conftest import get_provider_fixture_overrides, get_test_config_for_api
 
 from .fixtures import INFERENCE_FIXTURES
 
 
-def pytest_addoption(parser):
-    parser.addoption(
-        "--inference-model",
-        action="store",
-        default=None,
-        help="Specify the inference model to use for testing",
-    )
-    parser.addoption(
-        "--embedding-model",
-        action="store",
-        default=None,
-        help="Specify the embedding model to use for testing",
-    )
-
-
 def pytest_configure(config):
     for model in ["llama_8b", "llama_3b", "llama_vision"]:
         config.addinivalue_line(
@@ -58,16 +42,21 @@ VISION_MODEL_PARAMS = [
 
 
 def pytest_generate_tests(metafunc):
+    test_config = get_test_config_for_api(metafunc.config, "inference")
+
     if "inference_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--inference-model")
-        if model:
+        cls_name = metafunc.cls.__name__
+        params = []
+        inference_models = getattr(test_config, "inference_models", [])
+        for model in inference_models:
+            if ("Vision" in cls_name and "Vision" in model) or (
+                "Vision" not in cls_name and "Vision" not in model
+            ):
+                params.append(pytest.param(model, id=model))
+
+        if not params:
+            model = metafunc.config.getoption("--inference-model")
             params = [pytest.param(model, id="")]
-        else:
-            cls_name = metafunc.cls.__name__
-            if "Vision" in cls_name:
-                params = VISION_MODEL_PARAMS
-            else:
-                params = MODEL_PARAMS
 
         metafunc.parametrize(
             "inference_model",
@@ -83,4 +72,13 @@ def pytest_generate_tests(metafunc):
             },
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        if test_config:
+            if custom_fixtures := [
+                (
+                    scenario.fixture_combo_id
+                    or scenario.provider_fixtures.get("inference")
+                )
+                for scenario in test_config.scenarios
+            ]:
+                fixtures = custom_fixtures
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
```
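The model/class matching above is a simple parity rule: vision test classes get only vision models, and text classes only text models. Restated as a standalone sketch (the class names are illustrative, not from the diff):

```python
# A model is eligible for a test class iff both or neither mention "Vision".
def model_matches_class(cls_name: str, model: str) -> bool:
    return ("Vision" in cls_name) == ("Vision" in model)


assert model_matches_class("TestVisionModelInference", "meta-llama/Llama-3.2-11B-Vision-Instruct")
assert model_matches_class("TestModelInference", "meta-llama/Llama-3.1-8B-Instruct")
assert not model_matches_class("TestModelInference", "meta-llama/Llama-3.2-11B-Vision-Instruct")
```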
**`llama_stack/providers/tests/inference/fixtures.py`**

```diff
@@ -301,6 +301,7 @@ async def inference_stack(request, inference_model):
         inference_fixture.provider_data,
         models=[
             ModelInput(
+                provider_id=inference_fixture.providers[0].provider_id,
                 model_id=inference_model,
                 model_type=model_type,
                 metadata=metadata,
```
**`llama_stack/providers/tests/memory/conftest.py`**

```diff
@@ -6,7 +6,11 @@
 
 import pytest
 
-from ..conftest import get_provider_fixture_overrides
+from ..conftest import (
+    get_provider_fixture_overrides,
+    get_provider_fixture_overrides_from_test_config,
+    get_test_config_for_api,
+)
 
 from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES
@@ -56,15 +60,6 @@ DEFAULT_PROVIDER_COMBINATIONS = [
 ]
 
 
-def pytest_addoption(parser):
-    parser.addoption(
-        "--embedding-model",
-        action="store",
-        default=None,
-        help="Specify the embedding model to use for testing",
-    )
-
-
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(
@@ -74,8 +69,11 @@ def pytest_configure(config):
 
 
 def pytest_generate_tests(metafunc):
+    test_config = get_test_config_for_api(metafunc.config, "memory")
     if "embedding_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--embedding-model")
+        model = getattr(test_config, "embedding_model", None)
+        # Fall back to the default if not specified by the config file
+        model = model or metafunc.config.getoption("--embedding-model")
         if model:
             params = [pytest.param(model, id="")]
         else:
@@ -89,7 +87,10 @@ def pytest_generate_tests(metafunc):
             "memory": MEMORY_FIXTURES,
         }
         combinations = (
-            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            get_provider_fixture_overrides_from_test_config(
+                metafunc.config, "memory", DEFAULT_PROVIDER_COMBINATIONS
+            )
+            or get_provider_fixture_overrides(metafunc.config, available_fixtures)
             or DEFAULT_PROVIDER_COMBINATIONS
         )
         metafunc.parametrize("memory_stack", combinations, indirect=True)
```
**`llama_stack/providers/tests/post_training/test_post_training.py`**

```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import pytest
 
-from llama_stack.apis.common.type_system import JobStatus
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     Checkpoint,
     DataConfig,
```
**`llama_stack/providers/tests/safety/conftest.py`**

```diff
@@ -64,15 +64,6 @@ def pytest_configure(config):
         )
 
 
-def pytest_addoption(parser):
-    parser.addoption(
-        "--safety-shield",
-        action="store",
-        default=None,
-        help="Specify the safety shield to use for testing",
-    )
-
-
 SAFETY_SHIELD_PARAMS = [
     pytest.param(
         "meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"
```
**`llama_stack/providers/tests/scoring/conftest.py`**

```diff
@@ -55,21 +55,6 @@ def pytest_configure(config):
         )
 
 
-def pytest_addoption(parser):
-    parser.addoption(
-        "--inference-model",
-        action="store",
-        default="meta-llama/Llama-3.2-3B-Instruct",
-        help="Specify the inference model to use for testing",
-    )
-    parser.addoption(
-        "--judge-model",
-        action="store",
-        default="meta-llama/Llama-3.1-8B-Instruct",
-        help="Specify the judge model to use for testing",
-    )
-
-
 def pytest_generate_tests(metafunc):
     judge_model = metafunc.config.getoption("--judge-model")
     if "judge_model" in metafunc.fixturenames:
```
**`llama_stack/providers/tests/tools/conftest.py`**

```diff
@@ -34,21 +34,6 @@ def pytest_configure(config):
         )
 
 
-def pytest_addoption(parser):
-    parser.addoption(
-        "--inference-model",
-        action="store",
-        default="meta-llama/Llama-3.2-3B-Instruct",
-        help="Specify the inference model to use for testing",
-    )
-    parser.addoption(
-        "--safety-shield",
-        action="store",
-        default="meta-llama/Llama-Guard-3-1B",
-        help="Specify the safety shield to use for testing",
-    )
-
-
 def pytest_generate_tests(metafunc):
     if "tools_stack" in metafunc.fixturenames:
         available_fixtures = {
```