mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? - revamp and clean up datasets/scoring/eval integration tests - closes https://github.com/meta-llama/llama-stack/issues/1396 [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan **dataset** ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/integration/datasetio/ ``` <img width="842" alt="image" src="https://github.com/user-attachments/assets/88fc2b6a-b496-47bf-bc0c-8fea48ba36ff" /> **scoring** ``` LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/scoring --text-model meta-llama/Llama-3.1-8B-Instruct --judge-model meta-llama/Llama-3.1-8B-Instruct ``` <img width="851" alt="image" src="https://github.com/user-attachments/assets/50f46415-b44c-4c37-a6c3-076f2767adb3" /> **eval** ``` LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/eval --text-model meta-llama/Llama-3.1-8B-Instruct --judge-model meta-llama/Llama-3.1-8B-Instruct ``` <img width="841" alt="image" src="https://github.com/user-attachments/assets/8eb1c65c-3b39-4d66-8ff4-f471ca783e49" /> [//]: # (## Documentation)
160 lines
5.4 KiB
Python
160 lines
5.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import inspect
|
|
import itertools
|
|
import os
|
|
import textwrap
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
from .report import Report
|
|
|
|
|
|
def pytest_configure(config):
    """
    Prepare the pytest session: output style, environment variables, reporting.

    - Uses short tracebacks and disables warnings.
    - Loads variables via ``load_dotenv()``, then applies every ``--env KEY=value``
      override on top of them.
    - Registers the ``Report`` plugin when ``--report`` is given.

    Raises:
        ValueError: if an ``--env`` entry does not contain ``=``.
    """
    config.option.tbstyle = "short"
    config.option.disable_warnings = True

    # Validate every --env entry before touching the environment, so a malformed
    # entry fails fast with a clear message instead of a tuple-unpacking error
    # after some variables have already been applied.
    overrides = []
    for env_var in config.getoption("--env") or []:
        key, separator, value = env_var.partition("=")
        if not separator:
            raise ValueError(f"Invalid --env value {env_var!r}: expected the form KEY=value")
        overrides.append((key, value))

    load_dotenv()

    # Explicit --env values are applied after load_dotenv() so they take precedence.
    for key, value in overrides:
        os.environ[key] = value

    if config.getoption("--report"):
        config.pluginmanager.register(Report(config))
|
|
|
|
|
|
def pytest_addoption(parser):
    """
    Register the command-line options used by the integration tests.

    Options are declared in a table of ``(flag, keyword arguments)`` pairs and
    registered in that order through ``parser.addoption``.
    """
    stack_config_help = textwrap.dedent(
        """
        a 'pointer' to the stack. this can be either be:
        (a) a template name like `fireworks`, or
        (b) a path to a run.yaml file, or
        (c) an adhoc config spec, e.g. `inference=fireworks,safety=llama-guard,agents=meta-reference`
        """
    )

    option_table = (
        ("--stack-config", {"help": stack_config_help}),
        ("--env", {"action": "append", "help": "Set environment variables, e.g. --env KEY=value"}),
        ("--text-model", {"help": "comma-separated list of text models. Fixture name: text_model_id"}),
        ("--vision-model", {"help": "comma-separated list of vision models. Fixture name: vision_model_id"}),
        (
            "--embedding-model",
            {"help": "comma-separated list of embedding models. Fixture name: embedding_model_id"},
        ),
        ("--safety-shield", {"help": "comma-separated list of safety shields. Fixture name: shield_id"}),
        ("--judge-model", {"help": "Specify the judge model to use for testing"}),
        (
            "--embedding-dimension",
            {
                "type": int,
                "help": "Output dimensionality of the embedding model to use for testing. Default: 384",
            },
        ),
        (
            "--record-responses",
            {"action": "store_true", "help": "Record new API responses instead of using cached ones."},
        ),
        (
            "--report",
            {"help": "Path where the test report should be written, e.g. --report=/path/to/report.md"},
        ),
    )

    for flag, settings in option_table:
        parser.addoption(flag, **settings)
|
|
|
|
|
|
# Abbreviations for known model / shield identifiers, used to keep test IDs short.
MODEL_SHORT_IDS = {
    # Text and vision instruct models
    "meta-llama/Llama-3.2-3B-Instruct": "3B",
    "meta-llama/Llama-3.1-8B-Instruct": "8B",
    "meta-llama/Llama-3.1-70B-Instruct": "70B",
    "meta-llama/Llama-3.1-405B-Instruct": "405B",
    "meta-llama/Llama-3.2-11B-Vision-Instruct": "11B",
    "meta-llama/Llama-3.2-90B-Vision-Instruct": "90B",
    "meta-llama/Llama-3.3-70B-Instruct": "70B",
    # Safety shields
    "meta-llama/Llama-Guard-3-1B": "Guard1B",
    "meta-llama/Llama-Guard-3-8B": "Guard8B",
    # Embedding models
    "all-MiniLM-L6-v2": "MiniLM",
}


def get_short_id(value):
    """Return the abbreviation registered for ``value``, or ``value`` itself if there is none."""
    if value in MODEL_SHORT_IDS:
        return MODEL_SHORT_IDS[value]
    return value
|
|
|
|
|
|
def pytest_generate_tests(metafunc):
    """
    Parametrize a test over the model/shield options given on the command line.

    Every fixture listed in ``fixture_configs`` that the test requests becomes a
    parameter. Each option may be a comma-separated list of values; the test is
    generated once per combination (cartesian product) of the supplied values.
    An option that was not supplied contributes the single value ``None``.

    Test IDs have the form ``txt=8B:emb=MiniLM``. An ID only mentions parameters
    that appear in the test function's own signature and were actually supplied;
    if there are none, pytest's automatic IDs are used.
    """
    # Map of fixture name to its CLI option and ID prefix
    fixture_configs = {
        "text_model_id": ("--text-model", "txt"),
        "vision_model_id": ("--vision-model", "vis"),
        "embedding_model_id": ("--embedding-model", "emb"),
        "shield_id": ("--safety-shield", "shield"),
        "judge_model_id": ("--judge-model", "judge"),
        "embedding_dimension": ("--embedding-dimension", "dim"),
    }

    # Collect the requested fixtures and the values each should take
    params = []
    param_values = {}
    for fixture_name, (option, _prefix) in fixture_configs.items():
        if fixture_name not in metafunc.fixturenames:
            continue

        raw = metafunc.config.getoption(option)
        # str() because --embedding-dimension is parsed as an int; values are
        # always handed to the fixtures as strings.
        # Blank entries (e.g. from a trailing or doubled comma) are dropped:
        # they would otherwise become "" parameter values and leave the ID list
        # shorter than the list of combinations.
        entries = [item.strip() for item in str(raw).split(",")] if raw else []
        values = [entry for entry in entries if entry]

        params.append(fixture_name)
        param_values[fixture_name] = values or [None]

    if not params:
        return

    # Generate all combinations of parameter values
    value_combinations = list(itertools.product(*(param_values[name] for name in params)))

    # Only parameters named in the test function's signature (as opposed to those
    # pulled in indirectly through other fixtures) and with supplied values
    # contribute to the test ID.
    signature_params = inspect.signature(metafunc.function).parameters
    labelled = [
        (index, fixture_configs[name][1])
        for index, name in enumerate(params)
        if name in signature_params and param_values[name] != [None]
    ]

    test_ids = None
    if labelled:
        test_ids = [
            ":".join(f"{prefix}={get_short_id(combo[index])}" for index, prefix in labelled)
            for combo in value_combinations
        ]

    metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids)
|
|
|
|
|
|
# Modules that pytest should load as additional plugins; pytest reads this
# module-level `pytest_plugins` variable during collection.
# NOTE(review): presumably this module supplies the shared fixtures — confirm.
pytest_plugins = ["tests.integration.fixtures.common"]
|