refactor(test): introduce --stack-config and simplify options (#1404)

You can now run the integration tests with these options:

```bash
Custom options:
  --stack-config=STACK_CONFIG
                        a 'pointer' to the stack. this can be one of:
                        (a) a template name like `fireworks`, or
                        (b) a path to a run.yaml file, or
                        (c) an adhoc config spec, e.g.
                        `inference=fireworks,safety=llama-guard,agents=meta-
                        reference`
  --env=ENV             Set environment variables, e.g. --env KEY=value
  --text-model=TEXT_MODEL
                        comma-separated list of text models. Fixture name:
                        text_model_id
  --vision-model=VISION_MODEL
                        comma-separated list of vision models. Fixture name:
                        vision_model_id
  --embedding-model=EMBEDDING_MODEL
                        comma-separated list of embedding models. Fixture name:
                        embedding_model_id
  --safety-shield=SAFETY_SHIELD
                        comma-separated list of safety shields. Fixture name:
                        shield_id
  --judge-model=JUDGE_MODEL
                        comma-separated list of judge models. Fixture name:
                        judge_model_id
  --embedding-dimension=EMBEDDING_DIMENSION
                        Output dimensionality of the embedding model to use for
                        testing. Default: 384
  --record-responses    Record new API responses instead of using cached ones.
  --report=REPORT       Path where the test report should be written, e.g.
                        --report=/path/to/report.md

```
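
For concreteness, the three `--stack-config` forms map onto invocations roughly like the following sketch (run from `tests/integration`; the `run.yaml` path is a placeholder, while `fireworks` and the provider spec are taken from the help text above):

```bash
# (a) a distribution template name
pytest -s -v inference/ --stack-config=fireworks

# (b) a path to a run.yaml file (placeholder path)
pytest -s -v inference/ --stack-config=/path/to/run.yaml

# (c) an adhoc config spec naming one provider per API
pytest -s -v inference/ \
  --stack-config=inference=fireworks,safety=llama-guard,agents=meta-reference
```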

Importantly, if you don't specify any of the models (text-model,
vision-model, etc.), the relevant tests will get **skipped!**

This will make running tests somewhat more annoying, since all options
will need to be specified explicitly. We will make this easier by adding
some simple wrapper YAML configs.
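
For example, a run that should exercise both text and vision inference tests has to name both models explicitly; a sketch (run from `tests/integration`, with the vision model ID left as a placeholder):

```bash
# tests parametrized on a model option you omit are skipped, not failed
pytest -s -v inference/ \
  --stack-config=fireworks \
  --text-model=meta-llama/Llama-3.2-3B-Instruct \
  --vision-model=<your-vision-model-id>
```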

## Test Plan

Example:

```bash
ashwin@ashwin-mbp ~/local/llama-stack/tests/integration (unify_tests) $ 
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/test_text_inference.py \
   --text-model meta-llama/Llama-3.2-3B-Instruct 
```
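
The run above still points at the stack via the `LLAMA_STACK_CONFIG` environment variable; assuming the new flag accepts the same template name (which is the intent of this change), the equivalent flag-based invocation would be:

```bash
pytest -s -v inference/test_text_inference.py \
  --stack-config=fireworks \
  --text-model=meta-llama/Llama-3.2-3B-Instruct
```
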
Commit 2fe976ed0a (parent a0d6b165b0) by Ashwin Bharambe, 2025-03-05, committed via GitHub. 15 changed files with 536 additions and 1144 deletions.


```diff
@@ -5,18 +5,12 @@
 # the root directory of this source tree.
-import importlib
-import os
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional
-from urllib.parse import urlparse

 import pytest
 from pytest import CollectReport
 from termcolor import cprint

-from llama_stack.env import get_env_or_fail
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import (
     all_registered_models,
@@ -68,27 +62,16 @@ SUPPORTED_MODELS = {
 class Report:
-    def __init__(self, report_path: Optional[str] = None):
-        if os.environ.get("LLAMA_STACK_CONFIG"):
-            config_path_or_template_name = get_env_or_fail("LLAMA_STACK_CONFIG")
-            if config_path_or_template_name.endswith(".yaml"):
-                config_path = Path(config_path_or_template_name)
-            else:
-                config_path = Path(
-                    importlib.resources.files("llama_stack") / f"templates/{config_path_or_template_name}/run.yaml"
-                )
-            if not config_path.exists():
-                raise ValueError(f"Config file {config_path} does not exist")
-            self.output_path = Path(config_path.parent / "report.md")
-            self.distro_name = None
-        elif os.environ.get("LLAMA_STACK_BASE_URL"):
-            url = get_env_or_fail("LLAMA_STACK_BASE_URL")
-            self.distro_name = urlparse(url).netloc
-            if report_path is None:
-                raise ValueError("Report path must be provided when LLAMA_STACK_BASE_URL is set")
-            self.output_path = Path(report_path)
-        else:
-            raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+    def __init__(self, config):
+        self.distro_name = None
+        self.config = config
+
+        stack_config = self.config.getoption("--stack-config")
+        if stack_config:
+            is_url = stack_config.startswith("http") or "//" in stack_config
+            is_yaml = stack_config.endswith(".yaml")
+            if not is_url and not is_yaml:
+                self.distro_name = stack_config

         self.report_data = defaultdict(dict)
         # test function -> test nodeid
@@ -109,6 +92,9 @@ class Report:
         self.test_data[report.nodeid] = outcome

     def pytest_sessionfinish(self, session):
+        if not self.client:
+            return
+
         report = []
         report.append(f"# Report for {self.distro_name} distribution")
         report.append("\n## Supported Models")
@@ -153,7 +139,8 @@ class Report:
             for test_name in tests:
                 model_id = self.text_model_id if "text" in test_name else self.vision_model_id
                 test_nodeids = self.test_name_to_nodeid[test_name]
-                assert len(test_nodeids) > 0
+                if not test_nodeids:
+                    continue

                 # There might be more than one parametrizations for the same test function. We take
                 # the result of the first one for now. Ideally we should mark the test as failed if
@@ -179,7 +166,8 @@ class Report:
             for capa, tests in capa_map.items():
                 for test_name in tests:
                     test_nodeids = self.test_name_to_nodeid[test_name]
-                    assert len(test_nodeids) > 0
+                    if not test_nodeids:
+                        continue
                     test_table.append(
                         f"| {provider_str} | /{api} | {capa} | {test_name} | {self._print_result_icon(self.test_data[test_nodeids[0]])} |"
                     )
@@ -195,16 +183,15 @@ class Report:
             self.test_name_to_nodeid[func_name].append(item.nodeid)

             # Get values from fixtures for report output
-            if "text_model_id" in item.funcargs:
-                text_model = item.funcargs["text_model_id"].split("/")[1]
+            if model_id := item.funcargs.get("text_model_id"):
+                text_model = model_id.split("/")[1]
                 self.text_model_id = self.text_model_id or text_model
-            elif "vision_model_id" in item.funcargs:
-                vision_model = item.funcargs["vision_model_id"].split("/")[1]
+            elif model_id := item.funcargs.get("vision_model_id"):
+                vision_model = model_id.split("/")[1]
                 self.vision_model_id = self.vision_model_id or vision_model

-            if self.client is None and "llama_stack_client" in item.funcargs:
-                self.client = item.funcargs["llama_stack_client"]
-                self.distro_name = self.distro_name or self.client.async_client.config.image_name
+            if not self.client:
+                self.client = item.funcargs.get("llama_stack_client")

     def _print_result_icon(self, result):
         if result == "Passed":
```