mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
support multiple model ids for testing
This commit is contained in:
parent
113b17679d
commit
c19350f4ed
4 changed files with 59 additions and 54 deletions
|
@ -3,5 +3,3 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
# Make tests directory a Python package
|
|
||||||
|
|
|
@ -3,6 +3,8 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import inspect
|
||||||
|
import itertools
|
||||||
import os
|
import os
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
@ -49,30 +51,28 @@ def pytest_addoption(parser):
|
||||||
parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
|
parser.addoption("--env", action="append", help="Set environment variables, e.g. --env KEY=value")
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--text-model",
|
"--text-model",
|
||||||
help="Specify the text model to use for testing. Fixture name: text_model_id",
|
help="comma-separated list of text models. Fixture name: text_model_id",
|
||||||
)
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--vision-model",
|
"--vision-model",
|
||||||
help="Specify the vision model to use for testing. Fixture name: vision_model_id",
|
help="comma-separated list of vision models. Fixture name: vision_model_id",
|
||||||
)
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--embedding-model",
|
"--embedding-model",
|
||||||
help="Specify the embedding model to use for testing. Fixture name: embedding_model_id",
|
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
||||||
)
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--safety-shield",
|
"--safety-shield",
|
||||||
default="meta-llama/Llama-Guard-3-1B",
|
help="comma-separated list of safety shields. Fixture name: shield_id",
|
||||||
help="Specify the safety shield model to use for testing",
|
|
||||||
)
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--judge-model",
|
"--judge-model",
|
||||||
help="Specify the judge model to use for testing. Fixture name: judge_model_id",
|
help="comma-separated list of judge models. Fixture name: judge_model_id",
|
||||||
)
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--embedding-dimension",
|
"--embedding-dimension",
|
||||||
type=int,
|
type=int,
|
||||||
default=384,
|
help="Output dimensionality of the embedding model to use for testing. Default: 384",
|
||||||
help="Output dimensionality of the embedding model to use for testing",
|
|
||||||
)
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--record-responses",
|
"--record-responses",
|
||||||
|
@ -104,56 +104,65 @@ def get_short_id(value):
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
|
"""
|
||||||
|
This is the main function which processes CLI arguments and generates various combinations of parameters.
|
||||||
|
It is also responsible for generating test IDs which are succinct enough.
|
||||||
|
|
||||||
|
Each option can be comma separated list of values which results in multiple parameter combinations.
|
||||||
|
"""
|
||||||
params = []
|
params = []
|
||||||
values = []
|
param_values = {}
|
||||||
id_parts = []
|
id_parts = []
|
||||||
|
|
||||||
if "text_model_id" in metafunc.fixturenames:
|
# Map of fixture name to its CLI option and ID prefix
|
||||||
params.append("text_model_id")
|
fixture_configs = {
|
||||||
val = metafunc.config.getoption("--text-model")
|
"text_model_id": ("--text-model", "txt"),
|
||||||
values.append(val)
|
"vision_model_id": ("--vision-model", "vis"),
|
||||||
|
"embedding_model_id": ("--embedding-model", "emb"),
|
||||||
|
"shield_id": ("--safety-shield", "shield"),
|
||||||
|
"judge_model_id": ("--judge-model", "judge"),
|
||||||
|
"embedding_dimension": ("--embedding-dimension", "dim"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Collect all parameters and their values
|
||||||
|
for fixture_name, (option, id_prefix) in fixture_configs.items():
|
||||||
|
if fixture_name not in metafunc.fixturenames:
|
||||||
|
continue
|
||||||
|
|
||||||
|
params.append(fixture_name)
|
||||||
|
val = metafunc.config.getoption(option)
|
||||||
|
|
||||||
|
values = [v.strip() for v in str(val).split(",")] if val else [None]
|
||||||
|
param_values[fixture_name] = values
|
||||||
if val:
|
if val:
|
||||||
id_parts.append(f"txt={get_short_id(val)}")
|
id_parts.extend(f"{id_prefix}={get_short_id(v)}" for v in values)
|
||||||
|
|
||||||
if "vision_model_id" in metafunc.fixturenames:
|
if not params:
|
||||||
params.append("vision_model_id")
|
return
|
||||||
val = metafunc.config.getoption("--vision-model")
|
|
||||||
values.append(val)
|
|
||||||
if val:
|
|
||||||
id_parts.append(f"vis={get_short_id(val)}")
|
|
||||||
|
|
||||||
if "embedding_model_id" in metafunc.fixturenames:
|
# Generate all combinations of parameter values
|
||||||
params.append("embedding_model_id")
|
value_combinations = list(itertools.product(*[param_values[p] for p in params]))
|
||||||
val = metafunc.config.getoption("--embedding-model")
|
|
||||||
values.append(val)
|
|
||||||
if val:
|
|
||||||
id_parts.append(f"emb={get_short_id(val)}")
|
|
||||||
|
|
||||||
if "shield_id" in metafunc.fixturenames:
|
# Generate test IDs
|
||||||
params.append("shield_id")
|
test_ids = []
|
||||||
val = metafunc.config.getoption("--safety-shield")
|
non_empty_params = [(i, values) for i, values in enumerate(param_values.values()) if values[0] is not None]
|
||||||
values.append(val)
|
|
||||||
if val:
|
|
||||||
id_parts.append(f"shield={get_short_id(val)}")
|
|
||||||
|
|
||||||
if "judge_model_id" in metafunc.fixturenames:
|
# Get actual function parameters using inspect
|
||||||
params.append("judge_model_id")
|
test_func_params = set(inspect.signature(metafunc.function).parameters.keys())
|
||||||
val = metafunc.config.getoption("--judge-model")
|
|
||||||
values.append(val)
|
|
||||||
if val:
|
|
||||||
id_parts.append(f"judge={get_short_id(val)}")
|
|
||||||
|
|
||||||
if "embedding_dimension" in metafunc.fixturenames:
|
if non_empty_params:
|
||||||
params.append("embedding_dimension")
|
# For each combination, build an ID from the non-None parameters
|
||||||
val = metafunc.config.getoption("--embedding-dimension")
|
for combo in value_combinations:
|
||||||
values.append(val)
|
parts = []
|
||||||
if val != 384:
|
for param_name, val in zip(params, combo, strict=True):
|
||||||
id_parts.append(f"dim={val}")
|
# Only include if parameter is in test function signature and value is meaningful
|
||||||
|
if param_name in test_func_params and val:
|
||||||
|
prefix = fixture_configs[param_name][1] # Get the ID prefix
|
||||||
|
parts.append(f"{prefix}={get_short_id(val)}")
|
||||||
|
if parts:
|
||||||
|
test_ids.append(":".join(parts))
|
||||||
|
|
||||||
if params:
|
metafunc.parametrize(params, value_combinations, scope="session", ids=test_ids if test_ids else None)
|
||||||
# Create a single test ID string
|
|
||||||
test_id = ":".join(id_parts)
|
|
||||||
metafunc.parametrize(params, [values], scope="session", ids=[test_id])
|
|
||||||
|
|
||||||
|
|
||||||
pytest_plugins = ["tests.integration.fixtures.common"]
|
pytest_plugins = ["tests.integration.fixtures.common"]
|
||||||
|
|
|
@ -3,5 +3,3 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
# Make fixtures directory a Python package
|
|
||||||
|
|
|
@ -133,7 +133,7 @@ def client_with_models(
|
||||||
if judge_model_id and judge_model_id not in model_ids:
|
if judge_model_id and judge_model_id not in model_ids:
|
||||||
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
|
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
|
||||||
|
|
||||||
if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
|
if embedding_model_id and embedding_model_id not in model_ids:
|
||||||
# try to find a provider that supports embeddings, if sentence-transformers is not available
|
# try to find a provider that supports embeddings, if sentence-transformers is not available
|
||||||
selected_provider = None
|
selected_provider = None
|
||||||
for p in providers:
|
for p in providers:
|
||||||
|
@ -146,7 +146,7 @@ def client_with_models(
|
||||||
model_id=embedding_model_id,
|
model_id=embedding_model_id,
|
||||||
provider_id=selected_provider.provider_id,
|
provider_id=selected_provider.provider_id,
|
||||||
model_type="embedding",
|
model_type="embedding",
|
||||||
metadata={"embedding_dimension": embedding_dimension},
|
metadata={"embedding_dimension": embedding_dimension or 384},
|
||||||
)
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue