Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 23:51:00 +00:00)
Refactor safety shield fixtures
parent 743da9690b
commit 8645f8bc9e
6 changed files with 53 additions and 62 deletions
@@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
-from ..safety.fixtures import SAFETY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 from .fixtures import AGENTS_FIXTURES
@@ -75,28 +75,30 @@ def pytest_addoption(parser):
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
         action="store",
         default="Llama-Guard-3-8B",
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )


 def pytest_generate_tests(metafunc):
-    safety_model = metafunc.config.getoption("--safety-model")
-    if "safety_model" in metafunc.fixturenames:
+    shield_id = metafunc.config.getoption("--safety-shield")
+    if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
-            "safety_model",
-            [pytest.param(safety_model, id="")],
+            "safety_shield",
+            [pytest.param(shield_id, id="")],
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
         inference_model = metafunc.config.getoption("--inference-model")
-        models = list(set({inference_model, safety_model}))
+        models = set({inference_model})
+        if safety_model := safety_model_from_shield(shield_id):
+            models.add(safety_model)
+
         metafunc.parametrize(
             "inference_model",
-            [pytest.param(models, id="")],
+            [pytest.param(list(models), id="")],
             indirect=True,
         )
     if "agents_stack" in metafunc.fixturenames:
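A minimal sketch, not part of the commit, of what the reworked pytest_generate_tests computes: the inference model set starts from --inference-model and only gains a second entry when the chosen shield is backed by a model. safety_model_from_shield is copied from the helper added in the safety fixtures below; models_for and the model names are illustrative only.

def safety_model_from_shield(shield_id):
    # Same helper as in ..safety.fixtures: shields that are not model-backed
    # (Bedrock, CodeScanner, CodeShield) contribute no extra inference model.
    if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
        return None

    return shield_id


def models_for(inference_model, shield_id):
    # Hypothetical wrapper around the set-building logic in pytest_generate_tests.
    models = set({inference_model})
    if safety_model := safety_model_from_shield(shield_id):
        models.add(safety_model)
    return list(models)


print(models_for("Llama3.1-8B-Instruct", "Llama-Guard-3-8B"))
# e.g. ['Llama3.1-8B-Instruct', 'Llama-Guard-3-8B'] (set order may vary)
print(models_for("Llama3.1-8B-Instruct", "CodeScanner"))
# ['Llama3.1-8B-Instruct']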
@@ -19,7 +19,6 @@ from llama_stack.providers.inline.agents.meta_reference import (
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 from ..conftest import ProviderFixture, remote_stack_fixture
-from ..safety.fixtures import get_shield_to_register


 def pick_inference_model(inference_model):
@@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]


 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(request, inference_model, safety_model):
+async def agents_stack(request, inference_model, safety_shield):
     fixture_dict = request.param

     providers = {}
@@ -71,9 +70,6 @@ async def agents_stack(request, inference_model, safety_model):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)

-    shield_input = get_shield_to_register(
-        providers["safety"][0].provider_type, safety_model
-    )
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
@@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
             )
             for model in inference_models
         ],
-        shields=[shield_input],
+        shields=[safety_shield],
     )
     return impls[Api.agents], impls[Api.memory]
@@ -82,7 +82,7 @@ async def create_agent_session(agents_impl, agent_config):
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
-        self, safety_model, agents_stack, common_params
+        self, safety_shield, agents_stack, common_params
     ):
         agents_impl, _ = agents_stack
         agent_id, session_id = await create_agent_session(
@@ -90,8 +90,8 @@ class TestAgents:
             AgentConfig(
                 **{
                     **common_params,
-                    "input_shields": [safety_model],
-                    "output_shields": [safety_model],
+                    "input_shields": [safety_shield.shield_id],
+                    "output_shields": [safety_shield.shield_id],
                 }
             ),
         )
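A minimal sketch, not part of the commit, of the shape change in this test: the fixture now yields a ShieldInput-like object rather than a bare model name, so the test reads the registered identifier off safety_shield.shield_id. The dataclass below is a stand-in for the real ShieldInput, and the shield name is illustrative.

from dataclasses import dataclass, field


@dataclass
class FakeShieldInput:
    # Stand-in for llama_stack's ShieldInput; only the fields used here.
    shield_id: str
    params: dict = field(default_factory=dict)


safety_shield = FakeShieldInput(shield_id="Llama-Guard-3-8B")

agent_config_overrides = {
    "input_shields": [safety_shield.shield_id],
    "output_shields": [safety_shield.shield_id],
}
print(agent_config_overrides)
# {'input_shields': ['Llama-Guard-3-8B'], 'output_shields': ['Llama-Guard-3-8B']}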
@@ -66,14 +66,14 @@ def pytest_configure(config):

 def pytest_addoption(parser):
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
         action="store",
         default=None,
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )


-SAFETY_MODEL_PARAMS = [
+SAFETY_SHIELD_PARAMS = [
     pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
 ]
@@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
     # But a user can also pass in a custom combination via the CLI by doing
     # `--providers inference=together,safety=meta_reference`

-    if "safety_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--safety-model")
-        if model:
-            params = [pytest.param(model, id="")]
+    if "safety_shield" in metafunc.fixturenames:
+        shield_id = metafunc.config.getoption("--safety-shield")
+        if shield_id:
+            params = [pytest.param(shield_id, id="")]
         else:
-            params = SAFETY_MODEL_PARAMS
-        for fixture in ["inference_model", "safety_model"]:
+            params = SAFETY_SHIELD_PARAMS
+        for fixture in ["inference_model", "safety_shield"]:
             metafunc.parametrize(
                 fixture,
                 params,
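A minimal sketch, not part of the commit, of the fallback in the safety conftest: an explicit --safety-shield value produces a single unnamed parameter, otherwise the default SAFETY_SHIELD_PARAMS (the 1B guard, marked guard_1b) is used for both inference_model and safety_shield. select_params is an illustrative helper and assumes pytest is installed.

import pytest

SAFETY_SHIELD_PARAMS = [
    pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
]


def select_params(shield_id):
    # CLI value wins; otherwise fall back to the default shield parameter set.
    if shield_id:
        return [pytest.param(shield_id, id="")]
    return SAFETY_SHIELD_PARAMS


print(select_params(None))       # default guard_1b parametrization
print(select_params("bedrock"))  # single parameter carrying the CLI value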
@@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture:
     return remote_stack_fixture()


+def safety_model_from_shield(shield_id):
+    if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
+        return None
+
+    return shield_id
+
+
 @pytest.fixture(scope="session")
-def safety_model(request):
+def safety_shield(request):
     if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--safety-model", None)
+        shield_id = request.param
+    else:
+        shield_id = request.config.getoption("--safety-shield", None)
+
+    if shield_id == "bedrock":
+        shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
+        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
+    else:
+        params = {}
+
+    return ShieldInput(
+        shield_id=shield_id,
+        params=params,
+    )


 @pytest.fixture(scope="session")
-def safety_llama_guard(safety_model) -> ProviderFixture:
+def safety_llama_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::llama-guard",
+                provider_id="llama-guard",
                 provider_type="inline::llama-guard",
                 config=LlamaGuardConfig().model_dump(),
             )
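A minimal sketch, not part of the commit, of what the new safety_shield fixture builds: most shield ids pass straight through with empty params, while the special id "bedrock" is resolved from the two guardrail environment variables. A plain dict stands in for ShieldInput, and the environment values are placeholders.

import os

# Placeholder values standing in for real Bedrock guardrail settings.
os.environ.setdefault("BEDROCK_GUARDRAIL_IDENTIFIER", "my-guardrail-id")
os.environ.setdefault("BEDROCK_GUARDRAIL_VERSION", "1")


def build_shield_input(shield_id):
    # Mirrors the fixture body above; returns a dict instead of ShieldInput.
    if shield_id == "bedrock":
        shield_id = os.environ["BEDROCK_GUARDRAIL_IDENTIFIER"]
        params = {"guardrailVersion": os.environ["BEDROCK_GUARDRAIL_VERSION"]}
    else:
        params = {}
    return {"shield_id": shield_id, "params": params}


print(build_shield_input("Llama-Guard-3-8B"))
# {'shield_id': 'Llama-Guard-3-8B', 'params': {}}
print(build_shield_input("bedrock"))
# {'shield_id': 'my-guardrail-id', 'params': {'guardrailVersion': '1'}}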
@@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::prompt-guard",
+                provider_id="prompt-guard",
                 provider_type="inline::prompt-guard",
                 config=PromptGuardConfig().model_dump(),
             )
@@ -80,7 +99,7 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]


 @pytest_asyncio.fixture(scope="session")
-async def safety_stack(inference_model, safety_model, request):
+async def safety_stack(inference_model, safety_shield, request):
     # We need an inference + safety fixture to test safety
     fixture_dict = request.param
     inference_fixture = request.getfixturevalue(
@@ -98,32 +117,13 @@ async def safety_stack(inference_model, safety_model, request):
     if safety_fixture.provider_data:
         provider_data.update(safety_fixture.provider_data)

-    shield_provider_type = safety_fixture.providers[0].provider_type
-    shield_input = get_shield_to_register(shield_provider_type, safety_model)
-
-    print(f"inference_model: {inference_model}")
-    print(f"shield_input = {shield_input}")
     impls = await resolve_impls_for_test_v2(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
         models=[ModelInput(model_id=inference_model)],
-        shields=[shield_input],
+        shields=[safety_shield],
     )

-    shield = await impls[Api.shields].get_shield(shield_input.shield_id)
+    shield = await impls[Api.shields].get_shield(safety_shield.shield_id)
     return impls[Api.safety], impls[Api.shields], shield
-
-
-def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
-    if provider_type == "remote::bedrock":
-        identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
-        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
-    else:
-        params = {}
-        identifier = safety_model
-
-    return ShieldInput(
-        shield_id=identifier,
-        params=params,
-    )
@@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import *  # noqa: F403


 class TestSafety:
     @pytest.mark.asyncio
-    async def test_new_shield(self, safety_stack):
-        _, shields_impl, shield = safety_stack
-        assert shield is not None
-        assert shield.provider_resource_id == shield.identifier
-        assert shield.provider_id is not None
-
-    @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):
         _, shields_impl, _ = safety_stack