diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index b5e0db859..98b19616c 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -73,7 +73,7 @@ def pytest_configure(config): SAFETY_SHIELD_PARAMS = [ - pytest.param("meta-llama/Llama3.1-70B-Instruct", marks=pytest.mark.guard_1b, id="guard_1b"), + pytest.param("meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"), ] @@ -85,7 +85,7 @@ def pytest_generate_tests(metafunc): if "safety_shield" in metafunc.fixturenames: shield_id = metafunc.config.getoption("--safety-shield") if shield_id: - # assert shield_id.startswith("meta-llama/") + assert shield_id.startswith("meta-llama/") params = [pytest.param(shield_id, id="")] else: params = SAFETY_SHIELD_PARAMS