Refactor safety shield fixtures

Ashwin Bharambe 2024-11-12 15:58:58 -08:00
parent 743da9690b
commit 8645f8bc9e
6 changed files with 53 additions and 62 deletions

View file

@@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
-from ..safety.fixtures import SAFETY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 from .fixtures import AGENTS_FIXTURES
@@ -75,28 +75,30 @@ def pytest_addoption(parser):
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
        action="store",
         default="Llama-Guard-3-8B",
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )


 def pytest_generate_tests(metafunc):
-    safety_model = metafunc.config.getoption("--safety-model")
-    if "safety_model" in metafunc.fixturenames:
+    shield_id = metafunc.config.getoption("--safety-shield")
+    if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
-            "safety_model",
-            [pytest.param(safety_model, id="")],
+            "safety_shield",
+            [pytest.param(shield_id, id="")],
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
         inference_model = metafunc.config.getoption("--inference-model")
-        models = list(set({inference_model, safety_model}))
+        models = set({inference_model})
+        if safety_model := safety_model_from_shield(shield_id):
+            models.add(safety_model)
         metafunc.parametrize(
             "inference_model",
-            [pytest.param(models, id="")],
+            [pytest.param(list(models), id="")],
             indirect=True,
         )
     if "agents_stack" in metafunc.fixturenames:

View file

@@ -19,7 +19,6 @@ from llama_stack.providers.inline.agents.meta_reference import (
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 from ..conftest import ProviderFixture, remote_stack_fixture
-from ..safety.fixtures import get_shield_to_register


 def pick_inference_model(inference_model):
@@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]


 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(request, inference_model, safety_model):
+async def agents_stack(request, inference_model, safety_shield):
     fixture_dict = request.param

     providers = {}
@@ -71,9 +70,6 @@ async def agents_stack(request, inference_model, safety_model):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)

-    shield_input = get_shield_to_register(
-        providers["safety"][0].provider_type, safety_model
-    )
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
@@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
             )
             for model in inference_models
         ],
-        shields=[shield_input],
+        shields=[safety_shield],
     )
     return impls[Api.agents], impls[Api.memory]

View file

@@ -82,7 +82,7 @@ async def create_agent_session(agents_impl, agent_config):
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
-        self, safety_model, agents_stack, common_params
+        self, safety_shield, agents_stack, common_params
     ):
         agents_impl, _ = agents_stack
         agent_id, session_id = await create_agent_session(
@@ -90,8 +90,8 @@
             AgentConfig(
                 **{
                     **common_params,
-                    "input_shields": [safety_model],
-                    "output_shields": [safety_model],
+                    "input_shields": [safety_shield.shield_id],
+                    "output_shields": [safety_shield.shield_id],
                 }
             ),
         )

View file

@@ -66,14 +66,14 @@ def pytest_configure(config):
 def pytest_addoption(parser):
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
         action="store",
         default=None,
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )


-SAFETY_MODEL_PARAMS = [
+SAFETY_SHIELD_PARAMS = [
     pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
 ]
@@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
     # But a user can also pass in a custom combination via the CLI by doing
     # `--providers inference=together,safety=meta_reference`
-    if "safety_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--safety-model")
-        if model:
-            params = [pytest.param(model, id="")]
+    if "safety_shield" in metafunc.fixturenames:
+        shield_id = metafunc.config.getoption("--safety-shield")
+        if shield_id:
+            params = [pytest.param(shield_id, id="")]
         else:
-            params = SAFETY_MODEL_PARAMS
-        for fixture in ["inference_model", "safety_model"]:
+            params = SAFETY_SHIELD_PARAMS
+        for fixture in ["inference_model", "safety_shield"]:
             metafunc.parametrize(
                 fixture,
                 params,

View file

@@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture:
     return remote_stack_fixture()


+def safety_model_from_shield(shield_id):
+    if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
+        return None
+    return shield_id
+
+
 @pytest.fixture(scope="session")
-def safety_model(request):
+def safety_shield(request):
     if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--safety-model", None)
+        shield_id = request.param
+    else:
+        shield_id = request.config.getoption("--safety-shield", None)
+
+    if shield_id == "bedrock":
+        shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
+        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
+    else:
+        params = {}
+
+    return ShieldInput(
+        shield_id=shield_id,
+        params=params,
+    )


 @pytest.fixture(scope="session")
-def safety_llama_guard(safety_model) -> ProviderFixture:
+def safety_llama_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::llama-guard",
+                provider_id="llama-guard",
                 provider_type="inline::llama-guard",
                 config=LlamaGuardConfig().model_dump(),
             )
@@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::prompt-guard",
+                provider_id="prompt-guard",
                 provider_type="inline::prompt-guard",
                 config=PromptGuardConfig().model_dump(),
             )
@@ -80,7 +99,7 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]


 @pytest_asyncio.fixture(scope="session")
-async def safety_stack(inference_model, safety_model, request):
+async def safety_stack(inference_model, safety_shield, request):
     # We need an inference + safety fixture to test safety
     fixture_dict = request.param
     inference_fixture = request.getfixturevalue(
@@ -98,32 +117,13 @@ async def safety_stack(inference_model, safety_model, request):
     if safety_fixture.provider_data:
         provider_data.update(safety_fixture.provider_data)

-    shield_provider_type = safety_fixture.providers[0].provider_type
-    shield_input = get_shield_to_register(shield_provider_type, safety_model)
-
-    print(f"inference_model: {inference_model}")
-    print(f"shield_input = {shield_input}")
     impls = await resolve_impls_for_test_v2(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
         models=[ModelInput(model_id=inference_model)],
-        shields=[shield_input],
+        shields=[safety_shield],
     )

-    shield = await impls[Api.shields].get_shield(shield_input.shield_id)
+    shield = await impls[Api.shields].get_shield(safety_shield.shield_id)
     return impls[Api.safety], impls[Api.shields], shield
-
-
-def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
-    if provider_type == "remote::bedrock":
-        identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
-        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
-    else:
-        params = {}
-        identifier = safety_model
-
-    return ShieldInput(
-        shield_id=identifier,
-        params=params,
-    )
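Note: the registration logic that previously lived in get_shield_to_register is now folded into the safety_shield fixture. A minimal standalone sketch of what the fixture produces for the two --safety-shield cases; it assumes ShieldInput is importable from llama_stack.distribution.datatypes and uses placeholder Bedrock environment values.

import os

from llama_stack.distribution.datatypes import ShieldInput  # import path assumed


def build_shield_input(shield_id: str) -> ShieldInput:
    # Mirrors the fixture body: "bedrock" is resolved from environment variables,
    # any other shield id is registered as-is with no extra params.
    if shield_id == "bedrock":
        shield_id = os.environ["BEDROCK_GUARDRAIL_IDENTIFIER"]
        params = {"guardrailVersion": os.environ["BEDROCK_GUARDRAIL_VERSION"]}
    else:
        params = {}
    return ShieldInput(shield_id=shield_id, params=params)


os.environ.setdefault("BEDROCK_GUARDRAIL_IDENTIFIER", "my-guardrail-id")  # placeholder
os.environ.setdefault("BEDROCK_GUARDRAIL_VERSION", "1")  # placeholder

# --safety-shield=Llama-Guard-3-1B registers the shield directly by id,
# while --safety-shield=bedrock pulls the guardrail identifier and version from env.
print(build_shield_input("Llama-Guard-3-1B"))
print(build_shield_input("bedrock"))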

View file

@@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403


 class TestSafety:
-    @pytest.mark.asyncio
-    async def test_new_shield(self, safety_stack):
-        _, shields_impl, shield = safety_stack
-        assert shield is not None
-        assert shield.provider_resource_id == shield.identifier
-        assert shield.provider_id is not None
-
     @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):
         _, shields_impl, _ = safety_stack