support multiple model ids for testing

This commit is contained in:
Ashwin Bharambe 2025-03-05 11:21:53 -08:00
parent 113b17679d
commit c19350f4ed
4 changed files with 59 additions and 54 deletions

View file

@ -3,5 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Make tests directory a Python package

View file

@ -3,6 +3,8 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
import itertools
import os
import textwrap
@ -49,30 +51,28 @@ def pytest_addoption(parser):
parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
parser.addoption(
"--text-model",
help="Specify the text model to use for testing. Fixture name: text_model_id",
help="comma-separated list of text models. Fixture name: text_model_id",
)
parser.addoption(
"--vision-model",
help="Specify the vision model to use for testing. Fixture name: vision_model_id",
help="comma-separated list of vision models. Fixture name: vision_model_id",
)
parser.addoption(
"--embedding-model",
help="Specify the embedding model to use for testing. Fixture name: embedding_model_id",
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
)
parser.addoption(
"--safety-shield",
default="meta-llama/Llama-Guard-3-1B",
help="Specify the safety shield model to use for testing",
help="comma-separated list of safety shields. Fixture name: shield_id",
)
parser.addoption(
"--judge-model",
help="Specify the judge model to use for testing. Fixture name: judge_model_id",
help="comma-separated list of judge models. Fixture name: judge_model_id",
)
parser.addoption(
"--embedding-dimension",
type=int,
default=384,
help="Output dimensionality of the embedding model to use for testing",
help="Output dimensionality of the embedding model to use for testing. Default: 384",
)
parser.addoption(
"--record-responses",
@ -104,56 +104,65 @@ def get_short_id(value):
def pytest_generate_tests(metafunc):
"""
This is the main function which processes CLI arguments and generates various combinations of parameters.
It is also responsible for generating test IDs which are succinct enough.
Each option can be comma separated list of values which results in multiple parameter combinations.
"""
params = []
values = []
param_values = {}
id_parts = []
if "text_model_id" in metafunc.fixturenames:
params.append("text_model_id")
val = metafunc.config.getoption("--text-model")
values.append(val)
# Map of fixture name to its CLI option and ID prefix
fixture_configs = {
"text_model_id": ("--text-model", "txt"),
"vision_model_id": ("--vision-model", "vis"),
"embedding_model_id": ("--embedding-model", "emb"),
"shield_id": ("--safety-shield", "shield"),
"judge_model_id": ("--judge-model", "judge"),
"embedding_dimension": ("--embedding-dimension", "dim"),
}
# Collect all parameters and their values
for fixture_name, (option, id_prefix) in fixture_configs.items():
if fixture_name not in metafunc.fixturenames:
continue
params.append(fixture_name)
val = metafunc.config.getoption(option)
values = [v.strip() for v in str(val).split(",")] if val else [None]
param_values[fixture_name] = values
if val:
id_parts.append(f"txt={get_short_id(val)}")
id_parts.extend(f"{id_prefix}={get_short_id(v)}" for v in values)
if "vision_model_id" in metafunc.fixturenames:
params.append("vision_model_id")
val = metafunc.config.getoption("--vision-model")
values.append(val)
if val:
id_parts.append(f"vis={get_short_id(val)}")
if not params:
return
if "embedding_model_id" in metafunc.fixturenames:
params.append("embedding_model_id")
val = metafunc.config.getoption("--embedding-model")
values.append(val)
if val:
id_parts.append(f"emb={get_short_id(val)}")
# Generate all combinations of parameter values
value_combinations = list(itertools.product(*[param_values[p] for p in params]))
if "shield_id" in metafunc.fixturenames:
params.append("shield_id")
val = metafunc.config.getoption("--safety-shield")
values.append(val)
if val:
id_parts.append(f"shield={get_short_id(val)}")
# Generate test IDs
test_ids = []
non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None]
if "judge_model_id" in metafunc.fixturenames:
params.append("judge_model_id")
val = metafunc.config.getoption("--judge-model")
values.append(val)
if val:
id_parts.append(f"judge={get_short_id(val)}")
# Get actual function parameters using inspect
test_func_params = set(inspect.signature(metafunc.function).parameters.keys())
if "embedding_dimension" in metafunc.fixturenames:
params.append("embedding_dimension")
val = metafunc.config.getoption("--embedding-dimension")
values.append(val)
if val != 384:
id_parts.append(f"dim={val}")
if non_empty_params:
# For each combination, build an ID from the non-None parameters
for combo in value_combinations:
parts = []
for param_name, val in zip(params, combo, strict=True):
# Only include if parameter is in test function signature and value is meaningful
if param_name in test_func_params and val:
prefix = fixture_configs[param_name][1] # Get the ID prefix
parts.append(f"{prefix}={get_short_id(val)}")
if parts:
test_ids.append(":".join(parts))
if params:
# Create a single test ID string
test_id = ":".join(id_parts)
metafunc.parametrize(params, [values], scope="session", ids=[test_id])
metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None)
pytest_plugins = ["tests.integration.fixtures.common"]

View file

@ -3,5 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Make fixtures directory a Python package

View file

@ -133,7 +133,7 @@ def client_with_models(
if judge_model_id and judge_model_id not in model_ids:
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
if embedding_model_id and embedding_model_id not in model_ids:
# try to find a provider that supports embeddings, if sentence-transformers is not available
selected_provider = None
for p in providers:
@ -146,7 +146,7 @@ def client_with_models(
model_id=embedding_model_id,
provider_id=selected_provider.provider_id,
model_type="embedding",
metadata={"embedding_dimension": embedding_dimension},
metadata={"embedding_dimension": embedding_dimension or 384},
)
return client