refactor(test): introduce --stack-config and simplify options (#1404)

You now run the integration tests with these options:

```bash
Custom options:
  --stack-config=STACK_CONFIG
                        a 'pointer' to the stack. this can be either be:
                        (a) a template name like `fireworks`, or
                        (b) a path to a run.yaml file, or
                        (c) an adhoc config spec, e.g.
                        `inference=fireworks,safety=llama-guard,agents=meta-
                        reference`
  --env=ENV             Set environment variables, e.g. --env KEY=value
  --text-model=TEXT_MODEL
                        comma-separated list of text models. Fixture name:
                        text_model_id
  --vision-model=VISION_MODEL
                        comma-separated list of vision models. Fixture name:
                        vision_model_id
  --embedding-model=EMBEDDING_MODEL
                        comma-separated list of embedding models. Fixture name:
                        embedding_model_id
  --safety-shield=SAFETY_SHIELD
                        comma-separated list of safety shields. Fixture name:
                        shield_id
  --judge-model=JUDGE_MODEL
                        comma-separated list of judge models. Fixture name:
                        judge_model_id
  --embedding-dimension=EMBEDDING_DIMENSION
                        Output dimensionality of the embedding model to use for
                        testing. Default: 384
  --record-responses    Record new API responses instead of using cached ones.
  --report=REPORT       Path where the test report should be written, e.g.
                        --report=/path/to/report.md

```

Importantly, if you don't specify any of the models (text-model,
vision-model, etc.) the relevant tests will get **skipped!**

This will make running tests somewhat more annoying since all options
will need to be specified. We will make this easier by adding some easy
wrapper yaml configs.

## Test Plan

Example:

```bash
ashwin@ashwin-mbp ~/local/llama-stack/tests/integration (unify_tests) $ 
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/test_text_inference.py \
   --text-model meta-llama/Llama-3.2-3B-Instruct 
```
This commit is contained in:
Ashwin Bharambe 2025-03-05 17:02:02 -08:00 committed by GitHub
parent a0d6b165b0
commit 2fe976ed0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 536 additions and 1144 deletions

View file

@ -7,6 +7,7 @@
import importlib.resources
import os
import re
import tempfile
from typing import Any, Dict, Optional
import yaml
@ -33,10 +34,11 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import Api
@ -228,3 +230,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config