mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-19 15:38:52 +00:00
Refactor safety shield fixtures
This commit is contained in:
parent
743da9690b
commit
8645f8bc9e
6 changed files with 53 additions and 62 deletions
|
@ -66,14 +66,14 @@ def pytest_configure(config):
|
|||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--safety-model",
|
||||
"--safety-shield",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Specify the safety model to use for testing",
|
||||
help="Specify the safety shield to use for testing",
|
||||
)
|
||||
|
||||
|
||||
SAFETY_MODEL_PARAMS = [
|
||||
SAFETY_SHIELD_PARAMS = [
|
||||
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
|
||||
]
|
||||
|
||||
|
@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
|
|||
# But a user can also pass in a custom combination via the CLI by doing
|
||||
# `--providers inference=together,safety=meta_reference`
|
||||
|
||||
if "safety_model" in metafunc.fixturenames:
|
||||
model = metafunc.config.getoption("--safety-model")
|
||||
if model:
|
||||
params = [pytest.param(model, id="")]
|
||||
if "safety_shield" in metafunc.fixturenames:
|
||||
shield_id = metafunc.config.getoption("--safety-shield")
|
||||
if shield_id:
|
||||
params = [pytest.param(shield_id, id="")]
|
||||
else:
|
||||
params = SAFETY_MODEL_PARAMS
|
||||
for fixture in ["inference_model", "safety_model"]:
|
||||
params = SAFETY_SHIELD_PARAMS
|
||||
for fixture in ["inference_model", "safety_shield"]:
|
||||
metafunc.parametrize(
|
||||
fixture,
|
||||
params,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue