support multiple model ids for testing

This commit is contained in:
Ashwin Bharambe 2025-03-05 11:21:53 -08:00
parent 113b17679d
commit c19350f4ed
4 changed files with 59 additions and 54 deletions

View file

@ -3,5 +3,3 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Make tests directory a Python package

View file

@ -3,6 +3,8 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import inspect
import itertools
import os import os
import textwrap import textwrap
@ -49,30 +51,28 @@ def pytest_addoption(parser):
parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value") parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
parser.addoption( parser.addoption(
"--text-model", "--text-model",
help="Specify the text model to use for testing. Fixture name: text_model_id", help="comma-separated list of text models. Fixture name: text_model_id",
) )
parser.addoption( parser.addoption(
"--vision-model", "--vision-model",
help="Specify the vision model to use for testing. Fixture name: vision_model_id", help="comma-separated list of vision models. Fixture name: vision_model_id",
) )
parser.addoption( parser.addoption(
"--embedding-model", "--embedding-model",
help="Specify the embedding model to use for testing. Fixture name: embedding_model_id", help="comma-separated list of embedding models. Fixture name: embedding_model_id",
) )
parser.addoption( parser.addoption(
"--safety-shield", "--safety-shield",
default="meta-llama/Llama-Guard-3-1B", help="comma-separated list of safety shields. Fixture name: shield_id",
help="Specify the safety shield model to use for testing",
) )
parser.addoption( parser.addoption(
"--judge-model", "--judge-model",
help="Specify the judge model to use for testing. Fixture name: judge_model_id", help="comma-separated list of judge models. Fixture name: judge_model_id",
) )
parser.addoption( parser.addoption(
"--embedding-dimension", "--embedding-dimension",
type=int, type=int,
default=384, help="Output dimensionality of the embedding model to use for testing. Default: 384",
help="Output dimensionality of the embedding model to use for testing",
) )
parser.addoption( parser.addoption(
"--record-responses", "--record-responses",
@ -104,56 +104,65 @@ def get_short_id(value):
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
"""
This is the main function which processes CLI arguments and generates various combinations of parameters.
It is also responsible for generating test IDs which are succinct enough.
Each option can be comma separated list of values which results in multiple parameter combinations.
"""
params = [] params = []
values = [] param_values = {}
id_parts = [] id_parts = []
if "text_model_id" in metafunc.fixturenames: # Map of fixture name to its CLI option and ID prefix
params.append("text_model_id") fixture_configs = {
val = metafunc.config.getoption("--text-model") "text_model_id": ("--text-model", "txt"),
values.append(val) "vision_model_id": ("--vision-model", "vis"),
"embedding_model_id": ("--embedding-model", "emb"),
"shield_id": ("--safety-shield", "shield"),
"judge_model_id": ("--judge-model", "judge"),
"embedding_dimension": ("--embedding-dimension", "dim"),
}
# Collect all parameters and their values
for fixture_name, (option, id_prefix) in fixture_configs.items():
if fixture_name not in metafunc.fixturenames:
continue
params.append(fixture_name)
val = metafunc.config.getoption(option)
values = [v.strip() for v in str(val).split(",")] if val else [None]
param_values[fixture_name] = values
if val: if val:
id_parts.append(f"txt={get_short_id(val)}") id_parts.extend(f"{id_prefix}={get_short_id(v)}" for v in values)
if "vision_model_id" in metafunc.fixturenames: if not params:
params.append("vision_model_id") return
val = metafunc.config.getoption("--vision-model")
values.append(val)
if val:
id_parts.append(f"vis={get_short_id(val)}")
if "embedding_model_id" in metafunc.fixturenames: # Generate all combinations of parameter values
params.append("embedding_model_id") value_combinations = list(itertools.product(*[param_values[p] for p in params]))
val = metafunc.config.getoption("--embedding-model")
values.append(val)
if val:
id_parts.append(f"emb={get_short_id(val)}")
if "shield_id" in metafunc.fixturenames: # Generate test IDs
params.append("shield_id") test_ids = []
val = metafunc.config.getoption("--safety-shield") non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None]
values.append(val)
if val:
id_parts.append(f"shield={get_short_id(val)}")
if "judge_model_id" in metafunc.fixturenames: # Get actual function parameters using inspect
params.append("judge_model_id") test_func_params = set(inspect.signature(metafunc.function).parameters.keys())
val = metafunc.config.getoption("--judge-model")
values.append(val)
if val:
id_parts.append(f"judge={get_short_id(val)}")
if "embedding_dimension" in metafunc.fixturenames: if non_empty_params:
params.append("embedding_dimension") # For each combination, build an ID from the non-None parameters
val = metafunc.config.getoption("--embedding-dimension") for combo in value_combinations:
values.append(val) parts = []
if val != 384: for param_name, val in zip(params, combo, strict=True):
id_parts.append(f"dim={val}") # Only include if parameter is in test function signature and value is meaningful
if param_name in test_func_params and val:
prefix = fixture_configs[param_name][1] # Get the ID prefix
parts.append(f"{prefix}={get_short_id(val)}")
if parts:
test_ids.append(":".join(parts))
if params: metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None)
# Create a single test ID string
test_id = ":".join(id_parts)
metafunc.parametrize(params, [values], scope="session", ids=[test_id])
pytest_plugins = ["tests.integration.fixtures.common"] pytest_plugins = ["tests.integration.fixtures.common"]

View file

@ -3,5 +3,3 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Make fixtures directory a Python package

View file

@ -133,7 +133,7 @@ def client_with_models(
if judge_model_id and judge_model_id not in model_ids: if judge_model_id and judge_model_id not in model_ids:
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0]) client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids: if embedding_model_id and embedding_model_id not in model_ids:
# try to find a provider that supports embeddings, if sentence-transformers is not available # try to find a provider that supports embeddings, if sentence-transformers is not available
selected_provider = None selected_provider = None
for p in providers: for p in providers:
@ -146,7 +146,7 @@ def client_with_models(
model_id=embedding_model_id, model_id=embedding_model_id,
provider_id=selected_provider.provider_id, provider_id=selected_provider.provider_id,
model_type="embedding", model_type="embedding",
metadata={"embedding_dimension": embedding_dimension}, metadata={"embedding_dimension": embedding_dimension or 384},
) )
return client return client