Refactor safety shield fixtures

This commit is contained in:
Ashwin Bharambe 2024-11-12 15:58:58 -08:00
parent 743da9690b
commit 8645f8bc9e
6 changed files with 53 additions and 62 deletions

View file

@ -66,14 +66,14 @@ def pytest_configure(config):
def pytest_addoption(parser):
parser.addoption(
"--safety-model",
"--safety-shield",
action="store",
default=None,
help="Specify the safety model to use for testing",
help="Specify the safety shield to use for testing",
)
SAFETY_MODEL_PARAMS = [
SAFETY_SHIELD_PARAMS = [
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
]
@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
# But a user can also pass in a custom combination via the CLI by doing
# `--providers inference=together,safety=meta_reference`
if "safety_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--safety-model")
if model:
params = [pytest.param(model, id="")]
if "safety_shield" in metafunc.fixturenames:
shield_id = metafunc.config.getoption("--safety-shield")
if shield_id:
params = [pytest.param(shield_id, id="")]
else:
params = SAFETY_MODEL_PARAMS
for fixture in ["inference_model", "safety_model"]:
params = SAFETY_SHIELD_PARAMS
for fixture in ["inference_model", "safety_shield"]:
metafunc.parametrize(
fixture,
params,